In [102]:
import os
from os import path
import json
import sys
from collections import defaultdict

sys.path.append("/home/shtoshni/Research/events/src/")
from kbp_2015_utils.constants import EVENT_SUBTYPES

# from red_utils.constants import IDX_TO_ELEM_TYPE

## Set input and output directories

In [103]:
# Earlier configuration, kept for reference:
# input_file = "/home/shtoshni/Research/events/proc_data/kbp_2015"
# output_dir = "/home/shtoshni/Research/events/data/kbp_2014-2015/bert_html"

# JSONL log to visualise and the root directory for the generated HTML pages.
input_file = "/home/shtoshni/Research/events/models/ment_kbp_2015_cleaned_mlp_1000_model_base_emb_attn_type_spanbert_segments_3_width_4_ft/valid.log.jsonl"
base_output_dir = "/home/shtoshni/Research/events/models/html/ment_detection"

# The pages go into a sub-directory named after the directory holding the log.
model_name = path.basename(path.dirname(input_file))
output_dir = path.join(base_output_dir, model_name)
# exist_ok=True avoids the check-then-create race of a separate path.exists() test.
os.makedirs(output_dir, exist_ok=True)

## HTML Setup

In [104]:
# Document prologue shared by every generated page (per-document pages and index).
HTML_START = (
    '<!DOCTYPE html><html lang="en">'
    '<head><meta charset="UTF-8"></head><body>'
)

# Opening tag of a mention box; the placeholders are, in order:
# border style, border colour, padding in pixels.
start_tag_template = (
    '<div style="border:2px; display:inline; '
    'border-style: {}; border-color: {}; '
    'padding: {}px; padding-right: 3px; padding-left: 3px">'
)
# Closing tag matching start_tag_template.
end_tag = '</div>'

# Padding (px) of the first box opened at a token position, and the amount by
# which the padding shrinks for every further box opened at that same position.
largest_padding = 14
padding_reduction = 3


def get_tag_options(tag_type="gt"):
    """Return the ``(border_style, border_color)`` pair for a mention box.

    Args:
        tag_type: ``"pred"`` selects a red border; every other value
            (including the default ``"gt"``) selects the gold ``#FFD700`` border.

    Returns:
        tuple: ``('solid', color)`` where ``color`` is chosen as described above.
    """
    border_color = 'red' if tag_type == "pred" else '#FFD700'
    return ('solid', border_color)



In [105]:
# Names of the per-document HTML files written by this cell, in input order.
html_files = []


def _closing_markup(end_entries):
    """Build the markup that closes every mention ending at one token position.

    Args:
        end_entries: list of ``(span_start, mention_idx, end_tag, subscript)``
            tuples, in the order in which they were registered.

    Returns:
        str: for each entry, taken in reverse registration order, the subscript
        label wrapped in ``<sub>`` followed by that entry's closing tag.
    """
    # Reverse registration order so that the last registered mention closes first.
    return "".join(
        "<sub>" + subscript + "</sub>" + closing_tag
        for _, _, closing_tag, subscript in reversed(end_entries)
    )


with open(input_file, encoding="utf-8") as log_file:
    for line in log_file:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline) in the JSONL log.
            continue
        instance = json.loads(line)

        # Flatten instance["sentences"] into one token list and remember the
        # cumulative token count after each entry; a line break is emitted
        # before the token at each of those positions.
        bert_seg_idx = set()
        doc_list = []
        for sentence in instance["sentences"]:
            doc_list.extend(sentence)
            bert_seg_idx.add(len(doc_list))

        # token position -> {'start': [opening tags],
        #                    'end': [(span_start, mention_idx, end_tag, subscript)]}
        html_tag_list = {}

        # Register the opening/closing markup of every gold and predicted mention.
        for mention_type in ["gt_mentions", "pred_mentions"]:
            tag_type = mention_type.split("_")[0]
            tag_options = get_tag_options(tag_type)
            # Order by (start, end, subtype). A tuple key keeps the ordering
            # lexicographic for any index size, unlike a float-weighted sum,
            # which breaks once the span end reaches 10000.
            mentions = sorted(instance[mention_type], key=lambda m: (m[0], m[1], m[2]))
            for mention_idx, (span_start, span_end, event_subtype_val) in enumerate(mentions):
                # Label shown as a subscript where the mention closes.
                subscript = f'{EVENT_SUBTYPES[event_subtype_val]} {mention_idx}'
                span_end = span_end + 1  # Now span_end is not part of the span

                start_entry = html_tag_list.setdefault(span_start, defaultdict(list))
                end_entry = html_tag_list.setdefault(span_end, defaultdict(list))

                # Every additional box opened at the same token gets less padding
                # so that boxes sharing a start position stay distinguishable.
                padding = largest_padding - padding_reduction * len(start_entry['start'])
                start_entry['start'].append(start_tag_template.format(*tag_options, padding))
                # Subscript used in end
                end_entry['end'].append((span_start, mention_idx, end_tag, subscript))

        # Render the document token by token.
        html_parts = [HTML_START + '<div style="line-height: 3">']
        for token_idx, token in enumerate(doc_list):
            if token_idx in bert_seg_idx:
                html_parts.append("\n<br/>")

            tags_here = html_tag_list.get(token_idx)
            if tags_here is not None:
                # Close mentions that ended before this token, then open new ones.
                html_parts.append(_closing_markup(tags_here.get('end', [])))
                html_parts.extend(tags_here.get('start', []))

            html_parts.append(" " + token)

        # Mentions that end on the last token are registered at position
        # len(doc_list), which the token loop never visits; close them here so
        # their label and closing tag are not silently dropped.
        trailing_tags = html_tag_list.get(len(doc_list))
        if trailing_tags is not None:
            html_parts.append(_closing_markup(trailing_tags.get('end', [])))

        html_parts.append("</div></body></html>")
        html_string = "".join(html_parts)
        html_string = html_string.replace("\n", "\n<br/>")
        # NOTE(review): presumably "~" and "^" stand in for "<" and ">" in the
        # tokenised text - confirm against the preprocessing code.
        html_string = html_string.replace("~", "&lt;")
        html_string = html_string.replace("^", "&gt;")

        file_name = instance["doc_key"].replace("/", "-") + f"- ({instance['f_score']})" + ".html"
        file_path = path.join(output_dir, file_name)
        html_files.append(file_name)
        # Write as UTF-8 to match the charset declared in HTML_START; a distinct
        # handle name avoids shadowing the log file that is still being iterated.
        with open(file_path, "w", encoding="utf-8") as html_out:
            html_out.write(html_string)

In [106]:
# Build an index page: a numbered list with one link per generated document page.
index_entries = []
for html_file in html_files:
    # Link text: the file name without its extension, with "-" turned back into "/".
    base_name = path.splitext(path.basename(html_file))[0].replace("-", "/")
    # Attributes are separated by whitespace only; a comma between them is invalid HTML.
    index_entries.append(
        '<li> <a href="{}" target="_blank">'.format(html_file) + base_name + '</a></li>\n'
    )

index_html = HTML_START + '<ol type="1">' + "".join(index_entries) + '</ol>\n</body>\n</html>'

index_file_path = path.join(output_dir, "index.html")
print(index_file_path)
# Write as UTF-8 to match the charset declared in HTML_START.
with open(index_file_path, "w", encoding="utf-8") as index_out:
    index_out.write(index_html)

/home/shtoshni/Research/events/models/html/ment_detection/ment_kbp_2015_cleaned_mlp_1000_model_base_emb_attn_type_spanbert_segments_3_width_4_ft/index.html
