In [14]:
from mica_text_coref.coref.seq_coref import data
from mica_text_coref.coref.seq_coref import data_util
from mica_text_coref.coref.seq_coref import print_document

import numpy as np
import random
import re
import tqdm
from transformers import BertTokenizer, RobertaTokenizer, LongformerTokenizer
from typing import Callable

In [6]:
# Gold CoNLL-2012 English test split (jsonlines file) used throughout.
conll_test_path = ("/home/sbaruah_usc_edu/mica_text_coref/data/"
                   "conll-2012/gold/test.english.jsonlines")
test_corpus = data.CorefCorpus(conll_test_path)

In [7]:
# Subword tokenizers whose output is compared in the next cell. The Longformer
# tokenizer is loaded too, but it is not included in `tokenizers`.
bert_tokenizer, roberta_tokenizer, longformer_tokenizer = (
    BertTokenizer.from_pretrained("bert-base-cased"),
    RobertaTokenizer.from_pretrained("roberta-base"),
    LongformerTokenizer.from_pretrained("allenai/longformer-base-4096"),
)
tokenizers = list((bert_tokenizer, roberta_tokenizer))

In [8]:
# For one randomly sampled test document, tokenize its first two sentences in
# two ways -- each sentence separately, and both sentences as one string -- and
# print the resulting tokens and input ids for every tokenizer.
for sampled_document in random.sample(test_corpus.documents, 1):
    first_sentences = sampled_document.sentences[:2]
    sentence_strings = [" ".join(words) for words in first_sentences]
    flat_text = " ".join(word for words in first_sentences for word in words)

    print(f"Original Tokens\n{first_sentences}")
    print(f"Concatenated Sentence\n{sentence_strings}")
    print(f"Concatenated Text\n{flat_text}\n")

    for tok in tokenizers:
        per_sentence_ids = tok(sentence_strings)["input_ids"]
        per_sentence_tokens = list(
            map(tok.convert_ids_to_tokens, per_sentence_ids))
        whole_text_ids = tok(flat_text)["input_ids"]
        whole_text_tokens = tok.convert_ids_to_tokens(whole_text_ids)
        name = tok.name_or_path

        comparisons = (
            ("Concatenated Sentence", per_sentence_tokens, per_sentence_ids),
            ("Concatenated Text", whole_text_tokens, whole_text_ids),
        )
        for label, shown_tokens, shown_ids in comparisons:
            print(f"{name}({label})")
            print(shown_tokens)
            print(shown_ids)
        print()

    print()

Original Tokens
[['Of', 'all', 'the', 'ethnic', 'tensions', 'in', 'America', ',', 'which', 'is', 'the', 'most', 'troublesome', 'right', 'now', '?'], ['A', 'good', 'bet', 'would', 'be', 'the', 'tension', 'between', 'blacks', 'and', 'Jews', 'in', 'New', 'York', 'City', '.']]
Concatenated Sentence
['Of all the ethnic tensions in America , which is the most troublesome right now ?', 'A good bet would be the tension between blacks and Jews in New York City .']
Concatenated Text
Of all the ethnic tensions in America , which is the most troublesome right now ? A good bet would be the tension between blacks and Jews in New York City .

bert-base-cased(Concatenated Sentence)
[['[CLS]', 'Of', 'all', 'the', 'ethnic', 'tensions', 'in', 'America', ',', 'which', 'is', 'the', 'most', 'troubles', '##ome', 'right', 'now', '?', '[SEP]'], ['[CLS]', 'A', 'good', 'bet', 'would', 'be', 'the', 'tension', 'between', 'blacks', 'and', 'Jews', 'in', 'New', 'York', 'City', '.', '[SEP]']]
[[101, 2096, 1155, 1103, 

In [12]:
def remap_spans_document_level(
    corpus: data.CorefCorpus, tokenize_fn: Callable[[str], list[str]]) -> (
        data.CorefCorpus):
    """Retokenize every document of a corpus and remap its spans to the tokens.

    Each document's words are joined with single spaces and passed to
    `tokenize_fn` as one string, i.e. the tokenizer is applied at the document
    level. Words and tokens are then aligned character by character: every
    character of "".join(words) is greedily matched, in order, to the next
    identical character of "".join(tokens). Token characters with no word
    counterpart (for example marker characters a tokenizer prepends) are
    skipped.

    A word's first token is the token holding the word's first character and
    its last token is the token holding the word's last character. Cluster
    mentions, named entities, constituents and sentence boundaries are all
    remapped through these word-to-token maps. For every sentence, each new
    token is assigned the first entry of that sentence's speaker list.

    Args:
        corpus: Corpus whose documents are split into word-level sentences.
        tokenize_fn: Function splitting a string into subword tokens, e.g. a
            Hugging Face tokenizer's `tokenize` method.

    Returns:
        A new corpus with one document per input document, in the same order,
        whose sentences hold the subword tokens and whose mention offsets are
        token indices.

    Raises:
        ValueError: If a character of a document's words has no in-order match
            among the remaining token characters (e.g. the tokenizer emitted
            "[UNK]", changed letter case, or byte-encoded non-ASCII text).
            Raising here avoids silently mapping every later span of the
            document to token 0. Note that a dropped character which happens
            to reappear later in the text is matched there and is not detected.
    """
    new_corpus = data.CorefCorpus()

    for document in tqdm.tqdm(corpus.documents):
        new_document = data.CorefDocument()
        words = [word for sentence in document.sentences for word in sentence]
        text = " ".join(words)
        tokens = tokenize_fn(text)
        word_characters = "".join(words)
        token_characters = "".join(tokens)

        # Offsets (into word_characters) of each word's first and last
        # character.
        word_begin_to_word_character = np.zeros(len(words), dtype=int)
        word_end_to_word_character = np.zeros(len(words), dtype=int)
        c = 0
        for i, word in enumerate(words):
            word_begin_to_word_character[i] = c
            word_end_to_word_character[i] = c + len(word) - 1
            c += len(word)

        # Greedy in-order alignment: each word character maps to the next
        # identical token character. A missing character must be reported:
        # otherwise the scan would run off the end of token_characters and
        # leave every later character (hence every later span) at 0.
        word_character_to_token_character = np.zeros(len(word_characters),
                                                     dtype=int)
        j = 0
        for i, character in enumerate(word_characters):
            j = token_characters.find(character, j)
            if j == -1:
                raise ValueError(
                    f"Cannot align character {character!r} at word-character "
                    f"offset {i} of document {document.doc_key!r} to the "
                    "tokenizer output")
            word_character_to_token_character[i] = j
            j += 1

        # Index of the token that owns each token character.
        token_character_to_token_index = np.zeros(len(token_characters),
                                                  dtype=int)
        c = 0
        for i, token in enumerate(tokens):
            token_character_to_token_index[c: c + len(token)] = i
            c += len(token)

        def map_begin(word_begin: int) -> int:
            # Plain int (not a numpy scalar) so new mentions hold builtin ints.
            return int(token_character_to_token_index[
                word_character_to_token_character[
                    word_begin_to_word_character[word_begin]]])

        def map_end(word_end: int) -> int:
            return int(token_character_to_token_index[
                word_character_to_token_character[
                    word_end_to_word_character[word_end]]])

        def remap(mention: data.Mention) -> data.Mention:
            return data.Mention(map_begin(mention.begin), map_end(mention.end))

        for cluster in document.clusters:
            new_document.clusters.append({remap(mention) for mention in cluster})

        for mention, ner_tag in document.named_entities.items():
            new_document.named_entities[remap(mention)] = ner_tag

        for mention, constituency_tag in document.constituents.items():
            new_document.constituents[remap(mention)] = constituency_tag

        new_sentences = []
        new_speakers = []
        word_offset, token_offset = 0, 0
        for sentence, speakers in zip(document.sentences, document.speakers):
            if not sentence:
                # An empty sentence has no last word to map; map_end would
                # otherwise index the previous word (or wrap around to the
                # document's last word for a leading empty sentence).
                new_sentences.append([])
                new_speakers.append([])
                continue
            word_offset += len(sentence)
            token_end = map_end(word_offset - 1) + 1
            new_sentence = tokens[token_offset: token_end]
            token_offset = token_end
            new_sentences.append(new_sentence)
            new_speakers.append([speakers[0]] * len(new_sentence))
        new_document.sentences = new_sentences
        new_document.speakers = new_speakers

        new_document.doc_key = document.doc_key
        new_corpus.documents.append(new_document)

    return new_corpus

In [13]:
# Retokenize the whole test corpus with the RoBERTa tokenizer and remap spans.
roberta_corpus = remap_spans_document_level(
    corpus=test_corpus,
    tokenize_fn=roberta_tokenizer.tokenize,
)

100%|██████████| 348/348 [00:02<00:00, 155.47it/s]


In [16]:
# Pick one test document at random and write its original and
# RoBERTa-retokenized renderings to text files for side-by-side inspection.
OUTPUT_DIR = "/home/sbaruah_usc_edu/mica_text_coref/data/temp"

doc_index = random.randrange(len(test_corpus.documents))
original_document = test_corpus.documents[doc_index]
roberta_document = roberta_corpus.documents[doc_index]
# The comparison relies on both corpora listing documents in the same order.
assert original_document.doc_key == roberta_document.doc_key, (
    original_document.doc_key, roberta_document.doc_key)

original_desc = print_document.pretty_format_coref_document(original_document)
roberta_desc = print_document.pretty_format_coref_document(roberta_document)

# Explicit UTF-8 so the dump does not depend on the machine's locale encoding.
for file_name, desc in (("original_document.txt", original_desc),
                        ("roberta_document.txt", roberta_desc)):
    with open(f"{OUTPUT_DIR}/{file_name}", "w", encoding="utf-8") as fw:
        fw.write(desc)