In [1]:
import json
from os import path
import os
import sys
import glob
sys.path.append("/home/shtoshni/Research/long-doc-coref/src")

from transformers import BertTokenizerFast
from inference.inference import Inference
from inference.tokenize_doc import DocumentState, split_into_segments

from coref_utils.utils import get_mention_to_cluster_idx

In [2]:
# Locations of the input data, the output directory and the trained model.
# NOTE(review): these are machine-specific absolute paths; adjust before
# running on another machine.
input_dir = "/home/shtoshni/Research/litbank_coref/data/ontonotes/independent"

# Keep only the ".512" variants of the jsonlines files. The substring test is
# applied to the file name (not the full path) so that a ".512" occurring in a
# parent directory name cannot accidentally select every file.
input_files = [
    filename
    for filename in glob.glob(path.join(input_dir, "*.jsonlines"))
    if '.512' in path.basename(filename)
]
print(input_files)

output_dir = "/home/shtoshni/Research/litbank_coref/data/ontonotes/independent_singletons"
# exist_ok=True replaces the separate path.exists() check and its
# check-then-create race.
os.makedirs(output_dir, exist_ok=True)
model_loc = "/home/shtoshni/Research/long-doc-coref/models/umem_singleton_round_1/model.pth"

tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')

['/home/shtoshni/Research/litbank_coref/data/ontonotes/independent/test.512.jsonlines', '/home/shtoshni/Research/litbank_coref/data/ontonotes/independent/dev.512.jsonlines', '/home/shtoshni/Research/litbank_coref/data/ontonotes/independent/train.512.jsonlines']


In [3]:
# Build the inference wrapper from the checkpoint stored at `model_loc`.
# NOTE(review): the configuration dict shown in this cell's output is
# presumably printed by the Inference constructor — confirm.
model = Inference(model_loc)

{'base_data_dir': '../data/', 'data_dir': '/share/data/speech/shtoshni/research/litbank_coref/data/ontonotes/overlap_singletons', 'base_model_dir': '/share/data/speech/shtoshni/research/litbank_coref/models', 'model_dir': '/share/data/speech/shtoshni/research/litbank_coref/models/coref_ontonotes_0f09c611ab1ac7bc2d56b8a01be6a763', 'dataset': 'ontonotes', 'conll_scorer': '/share/data/speech/shtoshni/research/litbank_coref/reference-coreference-scorers/scorer.pl', 'model_size': 'large', 'doc_enc': 'overlap', 'pretrained_bert_dir': '/share/data/speech/shtoshni/resources', 'max_segment_len': 512, 'max_span_width': 30, 'ment_emb': 'attn', 'use_gold_ments': False, 'top_span_ratio': 0.4, 'mem_type': 'unbounded', 'max_ents': None, 'eval_max_ents': None, 'mlp_size': 3000, 'mlp_depth': 1, 'entity_rep': 'wt_avg', 'emb_size': 20, 'cross_val_split': 0, 'use_curriculum': False, 'new_ent_wt': 1.0, 'num_train_docs': None, 'num_eval_docs': None, 'max_training_segments': 5, 'sample_invalid': 1.0, 'dropou

In [4]:
def get_sentences(tokens, sentence_map, subtoken_map):
    """Split ``tokens`` into consecutive sentences.

    ``sentence_map`` and ``subtoken_map`` are walked in parallel. For every run
    of consecutive equal values in ``sentence_map`` the last ``subtoken_map``
    value of that run is recorded as the index of the sentence's final token.
    ``tokens`` is then cut into back-to-back slices ending at those indices.

    Args:
        tokens: Sliceable sequence of tokens.
        sentence_map: Sentence id for each subtoken.
        subtoken_map: Token index for each subtoken.

    Returns:
        List of non-empty slices of ``tokens``, one per sentence; slices that
        turn out empty are dropped.
    """
    sentence_final_tokens = []
    previous_sentence = -1
    for word_idx, sent_id in zip(subtoken_map, sentence_map):
        if sent_id == previous_sentence:
            # Still inside the same sentence: move its end marker forward.
            sentence_final_tokens[-1] = word_idx
            continue
        sentence_final_tokens.append(word_idx)
        previous_sentence = sent_id

    result = []
    start = 0
    for final_token in sentence_final_tokens:
        stop = final_token + 1
        chunk = tokens[start:stop]
        if len(chunk) > 0:
            result.append(chunk)
        start = stop
    return result


In [None]:
# Add predicted singleton clusters to every document and write the result to
# `output_dir` under the same file name.
#
# Each entry of instance["sentences"] is fed to the model twice in a row
# (entry + entry). A predicted two-mention cluster whose second mention is
# exactly the first one shifted by the length of the entry links a span only
# to its own copy; such a span is recorded as a singleton unless it already
# belongs to one of the document's clusters.
for input_file in input_files:
    output_file = path.join(output_dir, path.basename(input_file))

    with open(input_file) as input_f, open(output_file, "w") as output_f:
        for line in input_f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            instance = json.loads(line)
            mention_to_cluster_dict = get_mention_to_cluster_idx(instance["clusters"])
            # Position of the current entry's first subtoken within the
            # document-level instance['subtoken_map'].
            word_offset = 0
            singleton_clusters = []
            for sentence in instance["sentences"]:
                if not sentence:
                    # Nothing to predict on; max() below would fail on an empty map.
                    continue

                document_state = DocumentState()
                document_state.subtokens = sentence + sentence
                subtoken_map = instance['subtoken_map'][word_offset: word_offset + len(sentence)]
                # Shift the second copy past the first so its word indices do
                # not collide; the shift is computed once instead of once per
                # element.
                copy_shift = max(subtoken_map) + 1
                document_state.subtoken_map = (subtoken_map +
                                               [copy_shift + tmp_idx for tmp_idx in subtoken_map])
                document_state.sentence_end = [False] * len(document_state.subtokens)

                # A subtoken ends a token when the next subtoken maps to a
                # different word index; the final subtoken always ends one.
                # (The comparison must run over the word indices in
                # subtoken_map, not over the subtoken strings.)
                doubled_map = document_state.subtoken_map
                token_end = [cur != nxt for cur, nxt in zip(doubled_map, doubled_map[1:])]
                token_end.append(True)
                document_state.token_end = token_end

                split_into_segments(document_state, document_state.sentence_end, document_state.token_end)
                document = document_state.finalize()

                output_dict = model.perform_coreference(document, doc_key=instance["doc_key"])
                mod_len = len(output_dict['tokenized_doc']['subtoken_map'])

                assert mod_len % 2 == 0, "doubled input must have an even number of subtokens"
                orig_len = mod_len // 2

                for cluster in output_dict['subtoken_idx_clusters']:
                    if len(cluster) != 2:
                        continue
                    ment1, ment2 = sorted(cluster, key=lambda x: x[0])

                    # The second copy occupies positions [orig_len, mod_len),
                    # so a mention starting exactly at orig_len belongs to it.
                    if ment2[0] < orig_len:
                        continue
                    shifted_ment2 = (ment2[0] - orig_len, ment2[1] - orig_len)
                    # Normalise to tuple so list-valued mentions compare equal too.
                    if tuple(ment1) != shifted_ment2:
                        continue

                    offset_corrected_ment = (ment1[0] + word_offset, ment1[1] + word_offset)
                    if offset_corrected_ment not in mention_to_cluster_dict:
                        singleton_clusters.append([offset_corrected_ment])

                word_offset += len(sentence)

            instance["clusters"].extend(singleton_clusters)
            output_f.write(json.dumps(instance) + "\n")
            