In [16]:
import torch
import os
import json
import glob
from collections import defaultdict, OrderedDict
from os import path
from transformers import BertTokenizerFast, BertTokenizer

In [17]:
# sample_file = "/home/shtoshni/1023_bleak_house_brat.conll.txt"
litbank_dir = "/home/shtoshni/Research/litbank_coref/litbank/coref/tsv"
ann_files = glob.glob("{}/*.ann".format(litbank_dir))
ann_files.sort()
story_files = [ann_file.replace('.ann', '.txt') for ann_file in ann_files]

output_dir = "/home/shtoshni/Research/litbank_coref/data"
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', add_special_tokens=False)

sos_id = tokenizer.cls_token_id
eos_id = tokenizer.sep_token_id

max_segment_len = 350

In [18]:
proc_stories = []
story_list = []
for story_file, ann_file in zip(story_files, ann_files):
    # The dictionary in which all the info is stored
    proc_story = OrderedDict()
    # Story name - Used as the key to doc
    story_name = path.basename(story_file).replace(".txt", "")
    story_list.append(path.basename(story_file))
    proc_story["doc_key"] = story_name
    proc_story["sentences"] = []
    
    # Load the story in a dictionary
    with open(story_file) as f:
        line_counter = 0
        
        # Doc level variables - Window of tokens
        windows = []
        subtoken_map = []
        sent_map = []
        
        sent_counter = 0
        token_counter = 0
        
        # Local variables
        cur_window = []
        
        sent_to_subword_map = []
        for line in f:
            if len(cur_window) == 0:
                # The old window of tokens has been released; starting with a fresh window of subtokens
                sent_tokens = [sos_id]
                sent_subtoken_map = [token_counter]
                # word idx to subword idx mapping
                word_to_subword_map = [0]
            else:
                sent_tokens = []
                sent_subtoken_map = []
                # word idx to subword idx mapping
                word_to_subword_map = []
            
            words = line.strip().split(" ")
            for word_idx, word in enumerate(words):
                token_ids = tokenizer.encode(word, add_special_tokens=False)
                sent_tokens.extend(token_ids)

                # Increase the sentence mapping
                # Add the token to subtoken map
                sent_subtoken_map.extend([token_counter] * len(token_ids))
                word_to_subword_map.extend([word_idx] * len(token_ids))

                token_counter += 1
                
            window_tokens = len(cur_window) + len(sent_tokens)
            if window_tokens <= (max_segment_len - 1):  # reserving 1 for [SEP]
                # The current sentence fits in the window
                cur_window.extend(sent_tokens)
                sent_map.extend([sent_counter] * len(sent_tokens))
                subtoken_map.extend(sent_subtoken_map)
                sent_to_subword_map.append(word_to_subword_map)
                
            else:
                # Means that the current window wasn't empty.
                # Need to add EOS ID to the last window before initializing the new window
                cur_window.append(eos_id)
                sent_map.append(sent_counter - 1)

                # Repeat the word/token idx for EOS token
                sent_to_subword_map[-1].append(sent_to_subword_map[-1][-1])
                subtoken_map.append(subtoken_map[-1])

                # Put the current window to the list of BERT-sentences/windows
                proc_story["sentences"].append(tokenizer.convert_ids_to_tokens(cur_window))

                # Add [CLS] and the current sentence
                cur_window = [sos_id] + sent_tokens
                sent_map.extend([sent_counter] * (len(sent_tokens) + 1))
                subtoken_map.extend([sent_subtoken_map[0]] + sent_subtoken_map)
                sent_to_subword_map.append([0] + word_to_subword_map)
                
                if (len(word_to_subword_map) + 1) > max_segment_len:
                    ### TODO: Handle a sentence longer than the max_segment_len
                    print("Sweet Glory: %d", len(word_to_subword_map) + 1)
                    print(story_name)
                
            sent_counter += 1
        
        if cur_window:
            cur_window.append(eos_id)
            sent_map.append(sent_counter - 1)

            # Repeat the word/token idx for EOS token
            sent_to_subword_map[-1].append(sent_to_subword_map[-1][-1])
            subtoken_map.append(subtoken_map[-1])
            proc_story["sentences"].append(tokenizer.convert_ids_to_tokens(cur_window))
            
        proc_story["sent_map"] = sent_map
        proc_story["subtoken_map"] = subtoken_map
        
        proc_story["tokenized_sentences"] = []
        proc_story["tokenized_doc"] = []
        for sentence in proc_story["sentences"]:
            sent_tokens = tokenizer.convert_tokens_to_ids(sentence)
            proc_story["tokenized_sentences"].append(sent_tokens)
            proc_story["tokenized_doc"].extend(sent_tokens)
            
        sent_idx_to_subword_offset = [0]
        subword_counter = 0
        for idx, word_to_subword_map in enumerate(sent_to_subword_map):
            num_subwords = len(word_to_subword_map)
            subword_counter += num_subwords
            sent_idx_to_subword_offset.append(subword_counter)
        
        assert(subword_counter == len(proc_story["subtoken_map"]))       
        assert(sum([len(sentence) for sentence in  proc_story["sentences"]]) == len(proc_story["subtoken_map"]))
        
        # Get the cluster information
        with open(ann_file) as f:
            mention_dict = {}
            for line in f:
                cols = line.strip().split("\t")
                if cols[0] == 'MENTION':
                    mention_id = cols[1]
                    start_line, start_word_idx = int(cols[2]), int(cols[3])
                    # Sentence offset
                    start_line_offset = sent_idx_to_subword_offset[start_line]
                    
                    # Map the start word index to subword start index
                    start_subword_idx = (start_line_offset + 
                                         sent_to_subword_map[start_line].index(start_word_idx))
                    
                    end_line, end_word_idx = int(cols[4]), int(cols[5])
                    # Sentence offset
                    end_line_offset = sent_idx_to_subword_offset[end_line]                    
                    # Map the end word index to subword start index - Search for the last subword corresponding to it.
                    
                    end_word_start_idx = sent_to_subword_map[start_line].index(end_word_idx)
                    end_word_idx_count = sent_to_subword_map[start_line].count(end_word_idx)
                    end_subword_idx = (end_line_offset + end_word_start_idx + end_word_idx_count)
                    
                    given_mention_str = cols[6]
                    ent_type, mention_type = cols[7], cols[8]

                    assert(start_line == end_line)  # Check that no mention is across a line
                    mention_dict[mention_id] = [start_subword_idx, end_subword_idx]
    
        with open(ann_file) as f:
            coref_chains = OrderedDict()
            for line in f:
                cols = line.strip().split("\t")
                if cols[0] == 'COREF':
                    mention_id, cluster_id = cols[1], cols[2]
                    if not cluster_id in coref_chains:
                        coref_chains[cluster_id] = [mention_dict[mention_id]]
                    else:
                        coref_chains[cluster_id].append(mention_dict[mention_id])

        proc_story["clusters"] = [coref_chains[cluster_id] for cluster_id in coref_chains]
        proc_stories.append(proc_story)

Sweet Glory: %d 350
521_the_life_and_adventures_of_robinson_crusoe_brat
Sweet Glory: %d 331
521_the_life_and_adventures_of_robinson_crusoe_brat


In [19]:
# train_set = (0, 80)
# dev_set = (80, 88)
# test_set = (88, 96)
# final_set = (96, 100)

# for (start_idx, end_idx), split in zip([train_set, dev_set, test_set, final_set], 
#                                        ["train", "valid", "test", "final"]):
#     with open(path.join(output_dir, split + ".{}.jsonl".format(max_segment_len)), "w") as f:
#         for instance in proc_stories[start_idx: end_idx]:
#             f.write(json.dumps(instance) + "\n")
            
#     with open(path.join(output_dir, split + ".stories.txt"), "w") as g:
#         for story in story_list[start_idx: end_idx]:
#             g.write(story + "\n")    


In [20]:
with open(path.join(output_dir, "all.{}.jsonl".format(max_segment_len)), "w") as f:
    for instance in proc_stories:
        f.write(json.dumps(instance) + "\n")