In [57]:
import glob
import html
import json
import os
from collections import defaultdict
from os import path

import torch

In [58]:
# JSONL log with one document per line (gold "clusters" and "predicted_clusters").
# Set the COREF_LOG_FILE environment variable to render a different run without
# editing the notebook; the default keeps the original location.
input_file = os.environ.get("COREF_LOG_FILE", "/home/shtoshni/valid.log.jsonl")
# Rendered HTML pages go into an "output_logs" directory next to the log file.
output_dir = path.join(path.dirname(input_file), "output_logs")

# exist_ok avoids the check-then-create race of a separate path.exists() test.
os.makedirs(output_dir, exist_ok=True)

In [59]:
HTML_START = '<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"></head><body>'

# Opening tag for a boxed mention. The doubled braces survive the style/color
# substitution as a literal "{}" that is later filled with the per-depth padding.
_MENTION_TAG_TEMPLATE = (
    '<div style="border:2px; display:inline; border-style: {style}; '
    'border-color: {color}; padding: {{}}px; padding-right: 3px; padding-left: 3px">'
)


def _mention_tag(border_style, border_color):
    """Return an opening mention tag with the given CSS border style and color.

    The returned string still contains one "{}" placeholder for the padding value.
    """
    return _MENTION_TAG_TEMPLATE.format(style=border_style, color=border_color)


GT_COLOR = "green"      # ground-truth mentions
PRED_COLOR = "#FF69B4"  # predicted mentions (CSS HotPink)

# Solid borders are used for mentions of multi-mention clusters, dotted borders
# for mentions of singleton clusters.
gt_tag = _mention_tag("solid", GT_COLOR)
gt_tag_single = _mention_tag("dotted", GT_COLOR)
pred_tag = _mention_tag("solid", PRED_COLOR)
pred_tag_single = _mention_tag("dotted", PRED_COLOR)

end_tag = '</div>'

# Padding (px) of the outermost box; every additional level of nesting shrinks
# it by padding_reduction.
largest_padding = 14
padding_reduction = 2.5

In [60]:
def build_document(sentences):
    """Flatten tokenized sentences into one token list.

    The last token of every sentence gets a trailing newline, which the renderer
    turns into a <br/> line break. Empty sentences contribute no tokens, so
    skipping them leaves the global token indices unchanged.
    """
    document = []
    for sentence in sentences:
        if not sentence:
            continue
        document.extend(sentence[:-1])
        document.append(sentence[-1] + "\n")
    return document


def add_cluster_tags(tag_index, clusters, source, multi_tag, single_tag):
    """Register opening and closing tags for every mention in ``clusters``.

    Args:
        tag_index: token index -> {'start': [...], 'end': [...]}, where each list
            entry is a ``(mention_key, html_tag, subscript)`` triple. Mutated in place.
        clusters: list of clusters, each a list of ``(span_start, span_end)`` pairs.
        source: label used in the subscript and the mention key ("pred" or "gt").
        multi_tag: opening-tag template for clusters with more than one mention.
        single_tag: opening-tag template for singleton clusters.
    """
    for cluster_idx, cluster in enumerate(clusters):
        start_tag = single_tag if len(cluster) == 1 else multi_tag
        subscript = source + ' ' + str(cluster_idx)
        for mention_idx, (span_start, span_end) in enumerate(cluster):
            # span_end must stay first: opening tags are sorted on it. cluster_idx and
            # mention_idx make the key unique; a bare (span_end, source) key collides
            # when two mentions end on the same token, and closing them then deletes
            # the wrong entry from the open-mention stack.
            mention = (span_end, source, cluster_idx, mention_idx)
            tag_index[span_start]['start'].append((mention, start_tag, ''))
            # The cluster subscript is printed where the mention is closed.
            tag_index[span_end]['end'].append((mention, end_tag, subscript))


def _emit_close_tags(parts, end_tags, open_mentions):
    """Append closing tags, most recently opened first, and pop them off the stack."""
    positioned = [(open_mentions.index(mention), html_tag, subscript)
                  for mention, html_tag, subscript in end_tags]
    # Highest stack position = most recently opened; close those first (stack order).
    positioned.sort(key=lambda item: item[0], reverse=True)
    for stack_pos, html_tag, subscript in positioned:
        parts.append("<sub>" + subscript + "</sub>")
        parts.append(html_tag)
        # Deleting from the highest position down leaves lower positions valid.
        del open_mentions[stack_pos]


def _emit_open_tags(parts, start_tags, open_mentions):
    """Append opening tags, latest-ending mention first, with depth-dependent padding."""
    for mention, html_tag, _ in sorted(start_tags, key=lambda item: item[0][0], reverse=True):
        open_mentions.append(mention)
        # Deeper nesting gets less padding (never below 2px).
        padding_val = max(2, largest_padding - len(open_mentions) * padding_reduction)
        parts.append(html_tag.format(padding_val))


def render_document(document, tag_index):
    """Render ``document`` tokens as a full HTML page with boxed, subscripted mentions.

    NOTE(review): closing tags for span_end are emitted *before* token span_end,
    i.e. span_end is treated as exclusive. With inclusive ends, a single-token
    mention (span_start == span_end) would be closed before it is opened and
    raise ValueError; confirm the log's span convention.
    """
    parts = [HTML_START, '<div style="line-height: 3">']
    open_mentions = []
    for idx, token in enumerate(document):
        token_tags = tag_index.get(idx, {})
        # Close mentions ending here before opening those starting here, so
        # adjacent mentions become siblings rather than nested boxes.
        if 'end' in token_tags:
            _emit_close_tags(parts, token_tags['end'], open_mentions)
        if 'start' in token_tags:
            _emit_open_tags(parts, token_tags['start'], open_mentions)
        # Escape so tokens such as "<" or "&" cannot break the markup.
        parts.append(html.escape(token) + " ")
    # A mention ending at the end of the document has span_end == len(document);
    # that index is never visited above, so close such mentions here.
    trailing_tags = tag_index.get(len(document), {})
    if 'end' in trailing_tags:
        _emit_close_tags(parts, trailing_tags['end'], open_mentions)
    parts.append("</div></body></html>")
    return "".join(parts).replace("\n", "\n<br/>")


with open(input_file, encoding="utf-8") as in_f:
    for line in in_f:
        line = line.strip()
        if not line:
            continue  # tolerate blank / trailing lines in the JSONL file
        instance = json.loads(line)
        document = build_document(instance["sentences"])

        tag_index = defaultdict(lambda: defaultdict(list))
        # Predicted tags are registered before gold tags so that, among mentions
        # with identical spans, the predicted box is opened first.
        add_cluster_tags(tag_index, instance["predicted_clusters"], "pred",
                         pred_tag, pred_tag_single)
        add_cluster_tags(tag_index, instance["clusters"], "gt",
                         gt_tag, gt_tag_single)

        out_path = path.join(output_dir, instance["doc_key"] + ".html")
        # doc_key may contain "/" (e.g. genre-prefixed keys); create the subdirectory.
        os.makedirs(path.dirname(out_path), exist_ok=True)
        # UTF-8 to match the charset declared in HTML_START.
        with open(out_path, "w", encoding="utf-8") as out_f:
            out_f.write(render_document(document, tag_index))