In [64]:
import glob
import html
import json
import os
import random
import sys
from collections import defaultdict
from os import path

from transformers import LongformerTokenizerFast
# sys.path.append("../src/")

# from red_utils.constants import IDX_TO_ELEM_TYPE

## Set input and output directories

In [65]:
# Alternate (KBP) locations kept for reference; not used by this notebook.
# input_file = "/home/shtoshni/Research/events/proc_data/kbp_2015"
# output_dir = "/home/shtoshni/Research/events/data/kbp_2014-2015/bert_html"

# Directory holding one singleton-cluster file per setting (the file stem is
# reused below as the name of the output sub-directory).
input_dir = "/home/shtoshni/Research/litbank_coref/data/ontonotes/ment_singletons_longformer"
# Root directory under which the generated HTML pages are written.
base_output_dir = "/home/shtoshni/Research/litbank_coref/data/ontonotes/ontonotes_ment_singletons_html"

# JSON-lines file with the documents that get visualised.
train_file = "/home/shtoshni/Research/litbank_coref/data/ontonotes/independent_longformer/train.2048.jsonlines"
# Used later to turn the stored token ids back into token strings.
tokenizer = LongformerTokenizerFast.from_pretrained('allenai/longformer-large-4096', add_prefix_space=True)
singleton_files = glob.glob(path.join(input_dir, "*.jsonlines"))
print(singleton_files)

# exist_ok=True replaces the check-then-create pair, which could race with
# another process creating the directory between the two calls.
os.makedirs(base_output_dir, exist_ok=True)

['/home/shtoshni/Research/litbank_coref/data/ontonotes/ment_singletons_longformer/150.jsonlines', '/home/shtoshni/Research/litbank_coref/data/ontonotes/ment_singletons_longformer/90.jsonlines', '/home/shtoshni/Research/litbank_coref/data/ontonotes/ment_singletons_longformer/120.jsonlines', '/home/shtoshni/Research/litbank_coref/data/ontonotes/ment_singletons_longformer/30.jsonlines', '/home/shtoshni/Research/litbank_coref/data/ontonotes/ment_singletons_longformer/60.jsonlines']


In [66]:
def get_k_docs(file, k=10):
    """Return a reproducible random sample of documents from a JSON-lines file.

    Args:
        file: Path to a file with one JSON object per line; every object must
            have a 'doc_key' entry.
        k: Maximum number of documents to keep (default 10).

    Returns:
        A dict mapping 'doc_key' to the parsed document, holding at most ``k``
        documents. The selection is the first ``k`` entries of a shuffle
        seeded with 20, so repeated calls on the same file agree.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a sampled document has no 'doc_key'.
    """
    with open(file, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        data = [json.loads(line) for line in f if line.strip()]

    # A private generator yields the same order as seeding the module-level
    # one with 20, without resetting global random state for other code.
    random.Random(20).shuffle(data)
    return {instance['doc_key']: instance for instance in data[:k]}
    

## HTML Setup

In [67]:
# Prologue written at the top of every generated page (document pages and
# the per-directory index alike).
HTML_START = (
    '<!DOCTYPE html>'
    '<html lang="en">'
    '<head><meta charset="UTF-8"></head>'
    '<body>'
)

# Opening tag of a mention box; the three placeholders are, in order,
# border style, border colour and padding in pixels.
start_tag_template = (
    '<div style="border:2px; display:inline; '
    'border-style: {}; border-color: {}; '
    'padding: {}px; padding-right: 3px; padding-left: 3px">'
)
# Closing tag paired with start_tag_template.
end_tag = '</div>'

# Padding of the first box opened at a token, and how much less each further
# box opened at that same token receives.
largest_padding = 13
padding_reduction = 3


def get_tag_options(cluster):
    """Choose the border style and colour for a cluster's mention boxes.

    Args:
        cluster: Sized collection of mentions; only its length is inspected.

    Returns:
        A ``(border, color)`` tuple. The border is 'dotted' when the cluster
        holds exactly one mention and 'solid' otherwise; the colour is always
        '#0066CC'.
    """
    is_singleton = len(cluster) == 1
    return ('dotted' if is_singleton else 'solid', '#0066CC')



In [68]:
def _collect_mention_tags(clusters):
    """Index the mention boundaries of ``clusters`` by token position.

    Clusters are numbered by the order of their earliest mention start; that
    number is the subscript shown next to each mention.

    Args:
        clusters: List of clusters, each a list of ``[start, end]`` token
            spans with an inclusive ``end``.

    Returns:
        ``(starts, ends)`` where ``starts[i]`` lists
        ``(exclusive_end, cluster_idx, tag_options)`` for mentions opening at
        token ``i`` and ``ends[i]`` lists ``(span_start, cluster_idx)`` for
        mentions that closed just before token ``i``.
    """
    starts = defaultdict(list)
    ends = defaultdict(list)
    ordered = sorted(clusters, key=lambda cluster: min(ment[0] for ment in cluster))
    for cluster_idx, cluster in enumerate(ordered):
        tag_options = get_tag_options(cluster)
        for span_start, span_end in cluster:
            span_end = span_end + 1  # make the end exclusive
            starts[span_start].append((span_end, cluster_idx, tag_options))
            ends[span_end].append((span_start, cluster_idx))
    return starts, ends


def _closing_markup(ends_here):
    """Return subscript + closing tag for every mention in ``ends_here``.

    Every box uses the same closing tag, so each one closes the most recently
    opened box. The mention that started last (ties: highest cluster index)
    is therefore written first, which keeps each subscript inside its own box.
    """
    return [
        "<sub>" + str(cluster_idx) + "</sub>" + end_tag
        for _, cluster_idx in sorted(ends_here, reverse=True)
    ]


def _opening_markup(starts_here):
    """Return the opening tags for every mention in ``starts_here``.

    The longest mention is opened first so that it becomes the outermost box
    (ties: lowest cluster index, mirroring ``_closing_markup``). Padding
    shrinks with each additional box opened at the same token and is clamped
    at zero because CSS rejects negative padding.
    """
    ordered = sorted(starts_here, key=lambda tag: (-tag[0], tag[1]))
    return [
        start_tag_template.format(
            *tag_options, max(0, largest_padding - padding_reduction * depth))
        for depth, (_, _, tag_options) in enumerate(ordered)
    ]


def _render_document_html(instance, clusters):
    """Render one document as an HTML page with its mentions boxed.

    Args:
        instance: Document dict whose "sentences" entry is a list of segments,
            each a list of token ids understood by ``tokenizer``.
        clusters: Clusters to highlight (see ``_collect_mention_tags``).

    Returns:
        The complete HTML page as a string.
    """
    token_ids = []
    segment_ends = set()
    for sentence in instance["sentences"]:
        token_ids.extend(sentence)
        segment_ends.add(len(token_ids))

    starts, ends = _collect_mention_tags(clusters)

    pieces = [HTML_START + '<div style="line-height: 3">']
    for token_idx, token_id in enumerate(token_ids):
        if token_idx in segment_ends:
            pieces.append("\n<br/>")

        # Close the mentions that ended before this token, then open new ones.
        pieces.extend(_closing_markup(ends.get(token_idx, ())))
        pieces.extend(_opening_markup(starts.get(token_idx, ())))

        # Escape the token text so '<', '>' and '&' cannot be read as markup.
        token_text = tokenizer.convert_ids_to_tokens(token_id)
        pieces.append(" " + html.escape(token_text, quote=False))

    # Mentions ending on the final token have an exclusive end equal to the
    # document length, which the loop above never visits; close them here.
    pieces.extend(_closing_markup(ends.get(len(token_ids), ())))
    pieces.append("</div></body></html>")

    html_string = "".join(pieces)
    html_string = html_string.replace("\n", "\n<br/>")
    # NOTE(review): '~' and '^' are rewritten to angle-bracket entities; this
    # looks like a convention of the input data -- confirm it is still wanted.
    html_string = html_string.replace("~", "&lt;")
    html_string = html_string.replace("^", "&gt;")
    return html_string


data = get_k_docs(train_file)

for singleton_file in singleton_files:
    # One output sub-directory per singleton file, named after the file stem.
    subdir_name = path.basename(singleton_file).split(".")[0]
    output_dir = path.join(base_output_dir, subdir_name)
    os.makedirs(output_dir, exist_ok=True)

    # Read the whole mapping up front so the handle is not held open (or
    # shadowed) while the pages are being written.
    with open(singleton_file, encoding="utf-8") as singleton_f:
        doc_key_to_singleton_clusters = json.load(singleton_f)

    html_files = []
    for doc_key, instance in data.items():
        if doc_key not in doc_key_to_singleton_clusters:
            continue

        clusters = doc_key_to_singleton_clusters[doc_key]
        html_string = _render_document_html(instance, clusters)

        file_name = f"{len(clusters)} singletons - " + instance["doc_key"].replace("/", "-") + ".html"
        file_path = path.join(output_dir, file_name)
        html_files.append(file_name)
        # The page declares charset UTF-8, so write it as UTF-8 regardless of
        # the locale's default encoding.
        with open(file_path, "w", encoding="utf-8") as html_f:
            html_f.write(html_string)

    index_html = HTML_START + '<ol type="1">'

    for html_file in html_files:
        base_name = path.splitext(path.basename(html_file))[0].replace("-", "/")
        index_html += '<li> <a href="{}" target="_blank">'.format(html.escape(html_file)) \
            + html.escape(base_name, quote=False) + '</a></li>\n'

    index_html += '</ol>\n</body>\n</html>'
    index_file_path = path.join(output_dir, "index.html")
    print(index_file_path)
    with open(index_file_path, "w", encoding="utf-8") as index_f:
        index_f.write(index_html)

/home/shtoshni/Research/litbank_coref/data/ontonotes/ontonotes_ment_singletons_html/150/index.html
/home/shtoshni/Research/litbank_coref/data/ontonotes/ontonotes_ment_singletons_html/90/index.html
/home/shtoshni/Research/litbank_coref/data/ontonotes/ontonotes_ment_singletons_html/120/index.html
/home/shtoshni/Research/litbank_coref/data/ontonotes/ontonotes_ment_singletons_html/30/index.html
/home/shtoshni/Research/litbank_coref/data/ontonotes/ontonotes_ment_singletons_html/60/index.html
