In [5]:
import html
import json
import os
import sys
from collections import defaultdict
from os import path

sys.path.append("../src/")

from red_utils.constants import IDX_TO_ELEM_TYPE

## Set input and output directories

In [3]:
# Directory that holds one jsonlines file per split. Despite the name this is a
# directory: it is joined with `suffix.format(split)` to obtain each file path.
input_file = "/home/shtoshni/Research/events/proc_data/red/independent/"
# File-name template; `{}` is filled in with the split name.
suffix = "{}.512.jsonlines"

splits = ["train", "dev", "test"]

# Destination directory for the generated .html files.
output_dir = "/home/shtoshni/Research/events/data/red/bert_html"

# `exist_ok=True` replaces the check-then-create pattern: no race between the
# existence test and the creation, and a clear FileExistsError if the path is
# already taken by a regular file.
os.makedirs(output_dir, exist_ok=True)

## HTML Setup

In [4]:
# Document preamble; the closing </body></html> is appended once the body is built.
HTML_START = (
    '<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"></head><body>'
)

# Inline, solid-bordered box drawn around a mention; `{}` receives the border colour.
start_tag_template = (
    '<div style="border:2px; display:inline; border-style: solid; '
    'border-color: {}; padding: 10px; padding-right: 3px; padding-left: 3px">'
)

# One opening tag per mention type; they differ only in border colour.
entity_tag, event_tag = (
    start_tag_template.format(colour) for colour in ('#0066CC', 'violet')
)

# Closing tag shared by both kinds of mention box.
end_tag = '</div>'


In [None]:
# NOTE(review): mention spans are assumed to be [start, end] token indices with an
# INCLUSIVE end -- TODO confirm against the preprocessing code.
# Set to False if `span_end` already points one past the last token of the mention.
SPAN_END_INCLUSIVE = True

# Opening tag per mention type. Types other than ENTITY/EVENT get a gray box so
# that no mention is silently dropped or drawn with a stale tag.
START_TAG_BY_TYPE = {'ENTITY': entity_tag, 'EVENT': event_tag}
FALLBACK_START_TAG = start_tag_template.format('gray')


def render_doc_html(tokens, clusters):
    """Render one document as an HTML page with every mention boxed.

    Args:
        tokens: flat list of token strings for the whole document.
        clusters: list of clusters; each cluster is a list of mentions and each
            mention unpacks as ``(span_start, span_end, ent_type_idx)`` where
            ``ent_type_idx`` is looked up in ``IDX_TO_ELEM_TYPE``.

    Returns:
        The complete HTML document as a string. Tokens are HTML-escaped and
        joined with single spaces; each mention is wrapped in a coloured box
        followed by a ``<sub>`` label of the form ``"<type> <cluster index>"``.
    """
    starts = defaultdict(list)  # token idx -> [(mention_id, close_at, start_tag)]
    ends = defaultdict(list)    # token idx -> [mention_id] to close BEFORE that token
    subscript_of = {}           # mention_id -> label shown when the box closes

    for cluster_idx, cluster in enumerate(clusters):
        for mention_idx, (span_start, span_end, ent_type_idx) in enumerate(cluster):
            ent_type = IDX_TO_ELEM_TYPE[ent_type_idx]
            start_tag = START_TAG_BY_TYPE.get(ent_type, FALLBACK_START_TAG)

            close_at = span_end + 1 if SPAN_END_INCLUSIVE else span_end
            # A box must cover at least one token, otherwise its end would be
            # processed before its start.
            close_at = max(close_at, span_start + 1)

            # (cluster, position) is unique even if the same span occurs twice.
            mention_id = (cluster_idx, mention_idx)
            starts[span_start].append((mention_id, close_at, start_tag))
            ends[close_at].append(mention_id)
            subscript_of[mention_id] = html.escape(f"{ent_type} {cluster_idx}")

    pieces = [HTML_START, '<div style="line-height: 3">']

    # Acts like a stack: mentions are pushed when opened and removed from the
    # most recently opened to the least recently opened.
    active = []

    def close(mention_ids):
        """Emit subscript + closing tag for the given open mentions."""
        # Resolve all stack positions first, then delete from the highest
        # position down so the lower positions stay valid.
        positions = sorted(
            (active.index(m) for m in mention_ids if m in active), reverse=True
        )
        for pos in positions:
            pieces.append("<sub>" + subscript_of[active[pos]] + "</sub>")
            pieces.append(end_tag)
            del active[pos]

    for idx, token in enumerate(tokens):
        close(ends.get(idx, ()))
        if idx > 0:
            pieces.append(" ")
        # Open the longest span first so that it is the outermost box.
        for mention_id, _, start_tag in sorted(
            starts.get(idx, ()), key=lambda entry: entry[1], reverse=True
        ):
            active.append(mention_id)
            pieces.append(start_tag)
        # Escape per token so literal '<', '>', '&', '~' and '^' survive intact.
        pieces.append(html.escape(token).replace("\n", "\n<br/>"))

    # Mentions that end on (or run past) the last token are closed here.
    close(list(active))

    pieces.append("</div></body></html>")
    return "".join(pieces)


for split in splits:
    print(f"Processing {split.capitalize()}")
    # Read the source docs of this split (one JSON document per line)
    split_file = path.join(input_file, suffix.format(split))

    with open(split_file, encoding="utf-8") as split_f:
        for doc_idx, line in enumerate(split_f):
            line = line.strip()
            if not line:
                continue
            instance = json.loads(line)

            # Flatten the per-segment token lists into one document-level list
            doc = [token for sentence in instance["sentences"] for token in sentence]
            html_string = render_doc_html(doc, instance["clusters"])

            # NOTE(review): "doc_key" is assumed to identify the document -- confirm.
            # Falls back to "<split>_<line index>" so every document still gets
            # its own file; '/' is replaced to keep the name inside output_dir.
            base_name = str(instance.get("doc_key", f"{split}_{doc_idx}")).replace("/", "_")

            # The page declares charset UTF-8, so write it as UTF-8.
            with open(path.join(output_dir, base_name + ".html"), "w", encoding="utf-8") as html_f:
                html_f.write(html_string)