In [104]:
import glob
import html
import json
import os
from collections import defaultdict
from os import path

import torch

In [112]:
# Input: LitBank's CoNLL coreference files. Output: directory for the HTML rendering.
litbank_dir = "../../lrec2020-coref/data/original/conll"
output_dir = "../../litbank_html"

litbank_files = sorted(glob.glob(path.join(litbank_dir, "*.conll")))
assert len(litbank_files) == 100  # 100 documents in LitBank
print(len(litbank_files))

100


### Process CoNLL-formatted files to extract mention spans

In [106]:
def get_clusters(story_file):
    """Parse a LitBank CoNLL file into coreference clusters and a token list.

    The file is read as UTF-8. A blank line marks a sentence boundary and is
    recorded as a "\\n" pseudo-token, so token indices include those breaks.
    A line with at least 12 tab-separated columns is a token (the word is
    column 3). A line with exactly 13 columns also carries a coreference field
    in column 12, made of "|"-separated parts such as ``(38``, ``38)``,
    ``(38)``. Non-blank lines with fewer than 12 columns are skipped.

    Args:
        story_file: Path to the ``.conll`` file.

    Returns:
        A tuple ``(cluster_id_to_spans, all_tokens, num_newline_tokens)``:
        ``cluster_id_to_spans`` maps each integer cluster id to a list of
        inclusive ``[start, end]`` token-index spans; ``all_tokens`` is the
        list of tokens including the "\\n" sentence-break tokens; and
        ``num_newline_tokens`` is how many of those break tokens were added.

    Raises:
        ValueError: on a malformed coreference field, a closing bracket with
            no matching open span, or spans still open at end of file.
        AssertionError: if a span is still open at a sentence boundary.
    """
    story_name = path.basename(story_file)
    all_tokens = []
    token_counter = 0
    num_newline_tokens = 0
    # Finished spans, per cluster id.
    cluster_id_to_spans = defaultdict(list)
    # Start indices of spans opened but not yet closed, per cluster id. Each
    # list is used as a stack, so a cluster can nest inside itself.
    cluster_id_to_active_spans = defaultdict(list)

    # Explicit encoding: LitBank text is not pure ASCII, and the platform
    # default encoding varies by locale.
    with open(story_file, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line == "":
                all_tokens.append("\n")
                token_counter += 1
                num_newline_tokens += 1
                # No active span crosses the sentence boundary.
                assert len(cluster_id_to_active_spans) == 0
                continue

            cols = line.split("\t")
            if len(cols) == 13:
                # Parse the cluster token. Examples - (38 or 38)|37) or (28|(26)
                cluster_token = cols[12]
                for cluster_str in cluster_token.split("|"):
                    opens = cluster_str.startswith("(")
                    closes = cluster_str.endswith(")")
                    if opens and closes:
                        # Single-token mention, e.g. "(38)".
                        cluster_idx = int(cluster_str[1:-1])
                        cluster_id_to_spans[cluster_idx].append([token_counter, token_counter])
                    elif opens:
                        cluster_idx = int(cluster_str[1:])
                        cluster_id_to_active_spans[cluster_idx].append(token_counter)
                    elif closes:
                        cluster_idx = int(cluster_str[:-1])
                        # .get() so a stray closer doesn't create an empty entry.
                        open_starts = cluster_id_to_active_spans.get(cluster_idx)
                        if not open_starts:
                            raise ValueError(
                                f"{story_name}: closing {cluster_str!r} for cluster "
                                f"{cluster_idx} has no matching open span"
                            )
                        # Innermost open span of this cluster closes first.
                        start_idx = open_starts.pop()
                        if not open_starts:
                            del cluster_id_to_active_spans[cluster_idx]
                        cluster_id_to_spans[cluster_idx].append([start_idx, token_counter])
                    else:
                        raise ValueError(
                            f"{story_name}: malformed coreference field {cluster_token!r}"
                        )

            if len(cols) >= 12:
                all_tokens.append(cols[3])
                token_counter += 1

    if cluster_id_to_active_spans:
        raise ValueError(
            f"{story_name}: spans still open at end of file for clusters "
            f"{sorted(cluster_id_to_active_spans)}"
        )

    return cluster_id_to_spans, all_tokens, num_newline_tokens
            

### LitBank Stats

In [107]:
# Corpus-wide statistics. Each story's parse is also cached in story_to_info
# as (tokens, clusters) so the HTML step doesn't re-read the files.
total_tokens = total_mentions = total_clusters = singleton_clusters = 0
max_clusters = 0
max_cluster_story = None
story_to_info = {}

for story_file in litbank_files:
    cluster_id_to_spans, all_tokens, num_newline_tokens = get_clusters(story_file)
    story_to_info[story_file] = (all_tokens, cluster_id_to_spans)

    # The "\n" sentence-break pseudo-tokens are not counted as words.
    total_tokens += len(all_tokens) - num_newline_tokens

    num_clusters = len(cluster_id_to_spans)
    total_clusters += num_clusters
    # ">=" means that on a tie the later story is the one reported.
    if num_clusters >= max_clusters:
        max_clusters = num_clusters
        max_cluster_story = path.basename(story_file)

    mention_counts = [len(spans) for spans in cluster_id_to_spans.values()]
    total_mentions += sum(mention_counts)
    singleton_clusters += mention_counts.count(1)

print(f"Total tokens in LitBank: {total_tokens}")
print(f"# of Entity mentions: {total_mentions}, Total # of clusters: {total_clusters}")
print(f"Max clusters ({max_cluster_story}): {max_clusters}")
singleton_fraction = singleton_clusters / total_clusters
print(f"Fraction of singleton clusters among total clusters: {singleton_fraction:.2f}")


Total tokens in LitBank: 210532
# of Entity mentions: 29103, Total # of clusters: 7927
Max clusters (940_the_last_of_the_mohicans_a_narrative_of_1757_brat.conll): 199
Fraction of singleton clusters among total clusters: 0.73


### HTML Conversion

In [108]:
# Preamble shared by every generated page; it declares the document as UTF-8.
HTML_START = (
    '<!DOCTYPE html><html lang="en">'
    '<head><meta charset="UTF-8"></head>'
    '<body>'
)

# Inline boxes drawn around a mention. "{}" takes the padding in px; the
# padding-right/padding-left declarations that follow override its horizontal
# part. Mentions in clusters with several mentions get a solid border...
cluster_start_tag = (
    '<div style="border:2px; display : inline; '
    'border-style:solid; padding: {}px; '
    'padding-right: 3px; padding-left: 3px">'
)
# ...and mentions of singleton clusters get a dotted one.
singleton_start_tag = (
    '<div style="border:2px; display : inline; '
    'border-style:dotted; padding:{}px; '
    'padding-right: 3px; padding-left: 3px">'
)
end_tag = '</div>'

# Padding (px) of the outermost mention box. Each extra nesting level is
# drawn with padding_reduction px less.
largest_padding = 11
padding_reduction = 2

In [109]:
def return_html(story_file):
    """Render one LitBank story as an HTML page with coreference mentions boxed.

    Reads the token list and cluster spans for ``story_file`` from the
    module-level ``story_to_info`` cache, which the stats loop fills. Every
    mention is wrapped in an inline ``<div>`` and followed by its cluster id
    in a ``<sub>``. Singleton clusters use ``singleton_start_tag``; all other
    clusters use ``cluster_start_tag``. Each nesting level gets
    ``padding_reduction`` px less padding, never going below 0. A "\\n"
    sentence-break token becomes ``<br/>``.

    Args:
        story_file: Key into ``story_to_info`` (the path of the ``.conll`` file).

    Returns:
        The complete HTML document as a string.
    """
    all_tokens, cluster_id_to_spans = story_to_info[story_file]

    # Index every mention by its first token and by its last token.
    ment_start_dict = defaultdict(list)
    ment_end_dict = defaultdict(list)
    for cluster_idx, ment_list in cluster_id_to_spans.items():
        for ment_start, ment_end in ment_list:
            ment_start_dict[ment_start].append((ment_end, cluster_idx))
            ment_end_dict[ment_end].append((ment_start, cluster_idx))

    # At a shared start token, open longer spans first (later end first) so
    # they enclose the shorter ones. Identical spans open in ascending
    # cluster-id order.
    for ments in ment_start_dict.values():
        ments.sort(key=lambda m: (-m[0], m[1]))
    # At a shared end token, close shorter spans first (later start first).
    # Identical spans close in descending cluster-id order, the reverse of
    # opening, so each </div> is paired with the correct <sub> label.
    for ments in ment_end_dict.values():
        ments.sort(key=lambda m: (-m[0], -m[1]))

    parts = [HTML_START, '<div style="line-height: 3">']
    active_clusters = 0
    for token_idx, token in enumerate(all_tokens):
        if token == "\n":
            parts.append("<br/>\n")
            continue

        for _, cluster_idx in ment_start_dict.get(token_idx, ()):
            if len(cluster_id_to_spans[cluster_idx]) == 1:
                prefix = singleton_start_tag
            else:
                prefix = cluster_start_tag
            # Clamp at 0: beyond a few nesting levels the subtraction would
            # produce negative padding, which is invalid CSS.
            padding = max(0, largest_padding - active_clusters * padding_reduction)
            parts.append(prefix.format(padding))
            active_clusters += 1

        # Escape so a literal "<", ">" or "&" in the text cannot break the markup.
        parts.append(html.escape(token, quote=False) + " ")

        for _, cluster_idx in ment_end_dict.get(token_idx, ()):
            parts.append("<sub>" + str(cluster_idx) + "</sub>" + end_tag + " ")
            active_clusters -= 1
            assert active_clusters >= 0

    parts.append("</div></body></html>")
    return "".join(parts)

### Process all files to get an HTML version of the data

In [110]:
def extract_book_name(story_file):
    """Turn a LitBank file path into a human-readable book title.

    The file name is cut at its first ".", then split on "_". The first piece
    (the leading id) and the last piece (e.g. "brat") are dropped, and the
    remaining words are joined with spaces and passed through ``str.capitalize``.

    Example: ".../940_the_last_of_the_mohicans_a_narrative_of_1757_brat.conll"
    becomes "The last of the mohicans a narrative of 1757".
    """
    stem, _, _ = path.basename(story_file).partition(".")
    words = stem.split("_")
    return " ".join(words[1:-1]).capitalize()


# Write one HTML page per story, plus an index.html that links to all of them.
os.makedirs(output_dir, exist_ok=True)

index_items = []
for story_file in litbank_files:
    # "<id>_<title>_brat.conll" -> "<id>_<title>.html"
    output_file = path.basename(story_file).replace("conll", "html").replace("_brat", "")
    book_name = extract_book_name(story_file)
    # Attributes must be whitespace-separated: the old `href="...", target=...`
    # form was invalid HTML.
    index_items.append(
        '<li> <a href="{}" target="_blank">{}</a></li>\n'.format(
            html.escape(output_file), html.escape(book_name)
        )
    )

    book_html = return_html(story_file)
    # The pages declare charset="UTF-8", so they must be written as UTF-8
    # whatever the locale's default encoding is.
    with open(path.join(output_dir, output_file), "w", encoding="utf-8") as f:
        f.write(book_html)

index_html = HTML_START + '<ol type="1">' + "".join(index_items) + '</ol>\n</body>\n</html>'
index_file = path.join(output_dir, "index.html")
print(index_file)
with open(index_file, "w", encoding="utf-8") as g:
    g.write(index_html)

../../litbank_html/index.html
