In [1]:
from mica_text_coref.coref.seq_coref import data
from mica_text_coref.coref.seq_coref import data_util
from mica_text_coref.coref.seq_coref import representatives
from mica_text_coref.coref.seq_coref import tensorize
from mica_text_coref.coref.seq_coref import print_document

from keras_preprocessing import sequence
import numpy as np
import re
import torch
from torch import nn
from torch.utils import data as tdata
import tqdm
from transformers import LongformerTokenizer, LongformerModel
from typing import Callable
import unidecode

In [2]:
class CorefLongformerModel(nn.Module):
    """Longformer-based coreference resolution model for English.

    Constructing the model loads a pretrained Longformer tokenizer and encoder
    and creates three further layers: a label embedding, a bidirectional GRU
    and a linear token classifier. This class body only builds the layers; it
    does not define a forward pass.
    """

    def __init__(self, use_large: bool = False) -> None:
        """Load the pretrained Longformer and create the remaining layers.

        Args:
            use_large: If True load the "large" Longformer checkpoint,
                otherwise load the "base" checkpoint.
        """
        super().__init__()

        # Both the tokenizer and the encoder come from the same checkpoint.
        size_tag = "large" if use_large else "base"
        checkpoint_name = f"allenai/longformer-{size_tag}-4096"
        self.tokenizer: LongformerTokenizer = (
            LongformerTokenizer.from_pretrained(checkpoint_name))
        self.longformer: LongformerModel = (
            LongformerModel.from_pretrained(checkpoint_name))
        self.longformer_hidden_size: int = self.longformer.config.hidden_size

        # Embedding of the per-token label ids (3 distinct labels).
        self.n_labels = 3
        self.label_embedding_size = 10
        self.label_embedding = nn.Embedding(
            num_embeddings=self.n_labels,
            embedding_dim=self.label_embedding_size)

        # The GRU input is sized for the Longformer hidden state concatenated
        # with a label embedding.
        self.gru_hidden_size = self.longformer_hidden_size
        self.gru = nn.GRU(
            input_size=self.longformer_hidden_size + self.label_embedding_size,
            hidden_size=self.gru_hidden_size,
            bidirectional=True)

        # NOTE(review): a bidirectional nn.GRU emits 2 * hidden_size features
        # per step, while this classifier accepts gru_hidden_size features.
        # Confirm how the (not yet written) forward pass combines the two
        # directions before feeding the classifier.
        self.token_classifier = nn.Linear(
            in_features=self.gru_hidden_size, out_features=self.n_labels)

In [3]:
# Instantiate the model with the default use_large=False, i.e. from the
# "allenai/longformer-base-4096" checkpoint.
model = CorefLongformerModel()

Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerModel: ['lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing LongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
def create_tensors(
    corpus: data.CorefCorpus, 
    representative_mentions: list[list[data.Mention]],
    coref_longformer_model: CorefLongformerModel) -> (
        tdata.TensorDataset):
    """Create a tensor dataset from a coreference corpus and the representative
    mention of each document cluster.

    One example is created per (document, cluster) pair. The number of
    representative mentions of a document should equal its number of clusters.
    Make sure that you remap the spans of the documents before passing the
    corpus here.

    Mention ids and label ids use the same tagging scheme: 1 at the first token
    of a span, 2 at the remaining tokens of the span, 0 everywhere else.
    Mention ids tag the representative mention only; label ids tag every
    mention of the cluster. Global attention is switched on for the tokens of
    the representative mention.

    A cluster is skipped when the last mention of the sorted cluster ends at or
    beyond the maximum sequence length. All sequences are padded (or truncated)
    at the end to the maximum sequence length.

    Args:
        corpus: Coreference corpus.
        representative_mentions: List of list of data.Mention objects, one
            inner list per document and one mention per cluster.
        coref_longformer_model: Coreference Longformer Model. Only its
            tokenizer and its longformer config are used.

    Returns:
        A tensor pytorch dataset. It contains the following tensors:
            1. token ids: LongTensor
            2. mention ids: IntTensor
            3. label ids: IntTensor
            4. attention mask: FloatTensor
            5. global attention mask: FloatTensor
            6. doc ids: IntTensor

    Raises:
        AssertionError: If the number of documents differs from the number of
            representative mention lists, or if a document's number of
            clusters differs from its number of representative mentions.
    """
    assert len(corpus.documents) == len(representative_mentions), (
        "Number of documents should equal the number of representative mention"
        " lists")
    for document, mentions in zip(corpus.documents, representative_mentions):
        # The message has to be parenthesized: without the parentheses only the
        # first literal belongs to the assert and the doc key is never shown.
        assert len(document.clusters) == len(mentions), (
            "Number of clusters should equal the number of representative"
            f" mentions, doc key = {document.doc_key}")

    # Longest document in tokens, capped by the model's position embeddings.
    # NOTE(review): config.max_position_embeddings may be larger than the
    # sequence length the Longformer can actually consume (see the tokenizer's
    # model_max_length) - confirm before feeding these tensors to the model.
    longest_document = max(
        (sum(len(sentence) for sentence in document.sentences)
         for document in corpus.documents),
        default=0)
    max_sequence_length = min(
        coref_longformer_model.longformer.config.max_position_embeddings,
        longest_document)
    tokenizer = coref_longformer_model.tokenizer

    def tag_span(ids: list[int], span) -> None:
        """Write 1 at the first token of span and 2 at its remaining tokens."""
        ids[span.begin] = 1
        for k in range(span.begin + 1, span.end + 1):
            ids[k] = 2

    token_ids_list: list[list[int]] = []
    mention_ids_list: list[list[int]] = []
    label_ids_list: list[list[int]] = []
    attn_mask_list: list[list[int]] = []
    global_attn_mask_list: list[list[int]] = []
    doc_ids: list[int] = []

    for i, document in enumerate(corpus.documents):
        tokens = [token for sentence in document.sentences
                        for token in sentence]
        n_tokens = len(tokens)
        # Tokens are converted as they are; no special tokens are added here.
        token_ids: list[int] = tokenizer.convert_tokens_to_ids(tokens)
        attn_mask: list[int] = [1] * n_tokens
        doc_id = document.doc_id

        for j, cluster in enumerate(document.clusters):
            sorted_cluster = sorted(cluster)

            # Skip clusters that do not fit in the (possibly truncated) input.
            if sorted_cluster[-1].end >= max_sequence_length:
                continue

            representative = representative_mentions[i][j]
            mention_ids = [0] * n_tokens
            tag_span(mention_ids, representative)

            global_attn_mask = [0] * n_tokens
            for k in range(representative.begin, representative.end + 1):
                global_attn_mask[k] = 1

            label_ids = [0] * n_tokens
            for mention in sorted_cluster:
                tag_span(label_ids, mention)

            token_ids_list.append(token_ids)
            mention_ids_list.append(mention_ids)
            label_ids_list.append(label_ids)
            attn_mask_list.append(attn_mask)
            global_attn_mask_list.append(global_attn_mask)
            doc_ids.append(doc_id)

    token_ids_pt = torch.LongTensor(sequence.pad_sequences(token_ids_list,
        maxlen=max_sequence_length, dtype=int, padding="post",
        truncating="post", value=tokenizer.pad_token_id))
    mention_ids_pt = torch.IntTensor(sequence.pad_sequences(mention_ids_list,
        maxlen=max_sequence_length, dtype=int, padding="post",
        truncating="post", value=0))
    label_ids_pt = torch.IntTensor(sequence.pad_sequences(label_ids_list,
        maxlen=max_sequence_length, dtype=int, padding="post",
        truncating="post", value=0))
    attn_mask_pt = torch.FloatTensor(sequence.pad_sequences(attn_mask_list,
        maxlen=max_sequence_length, dtype=float, padding="post",
        truncating="post", value=0.))
    global_attn_mask_pt = torch.FloatTensor(sequence.pad_sequences(
        global_attn_mask_list,
        maxlen=max_sequence_length, dtype=float, padding="post",
        truncating="post", value=0.))
    doc_ids_pt = torch.IntTensor(doc_ids)

    dataset = tdata.TensorDataset(token_ids_pt, mention_ids_pt, label_ids_pt,
                                attn_mask_pt, global_attn_mask_pt, doc_ids_pt)
    return dataset

In [5]:
# Module-level shorthand for the model's Longformer tokenizer.
tokenizer = model.tokenizer

In [6]:
def lcs(word_characters: str, token_characters: str, 
        ignore_token_characters: list[str] = []) -> list[int]:
    """Align word_characters to token_characters through their longest common
    subsequence.

    Characters listed in ignore_token_characters never count as a match. A
    progress bar is shown over the characters of word_characters.

    Returns:
        A list with one entry per character of word_characters: the index of
        the token character it is aligned to. Characters left out of the
        subsequence keep the initial value 0, which cannot be told apart from
        an alignment to token character 0.
    """
    # Back-pointer codes; 0 only occurs at the top-left cell when it is not a
    # match.
    MATCH, LEFT, UP = 1, 2, 3

    n_word, n_token = len(word_characters), len(token_characters)
    length = np.zeros((n_word, n_token), dtype=int)
    direction = np.zeros((n_word, n_token), dtype=int)

    for row in tqdm.trange(n_word):
        for col in range(n_token):
            is_match = word_characters[row] == token_characters[col] and (
                token_characters[col] not in ignore_token_characters)
            if is_match:
                # On the first row/column there is no diagonal predecessor.
                diagonal = length[row - 1, col - 1] if row > 0 and col > 0 else 0
                length[row, col] = 1 + diagonal
                direction[row, col] = MATCH
            elif row == 0 and col == 0:
                # Both tables already hold 0 for this cell.
                continue
            elif row == 0:
                length[row, col] = length[row, col - 1]
                direction[row, col] = LEFT
            elif col == 0:
                length[row, col] = length[row - 1, col]
                direction[row, col] = UP
            elif length[row, col - 1] > length[row - 1, col]:
                length[row, col] = length[row, col - 1]
                direction[row, col] = LEFT
            else:
                # Ties are resolved by moving up.
                length[row, col] = length[row - 1, col]
                direction[row, col] = UP

    # Walk the back-pointers from the bottom-right corner.
    alignment = np.zeros(n_word, dtype=int)
    row, col = n_word - 1, n_token - 1
    while row >= 0 and col >= 0:
        step = direction[row, col]
        if step == LEFT:
            col -= 1
        elif step == UP:
            row -= 1
        else:
            if step == MATCH:
                alignment[row] = col
            row -= 1
            col -= 1

    return alignment.tolist()

In [7]:
def naive_mapping(word_characters: str, token_characters: str, 
        ignore_token_characters: list[str] = []) -> list[int]:
    """Greedily map every character of word_characters to a character of
    token_characters in a single left-to-right pass.

    The token characters are scanned once; the current word character is
    mapped to the first token character that equals it and is not listed in
    ignore_token_characters, after which the next word character is looked for.

    Returns:
        A list with one token character index per word character. Word
        characters that were not matched keep the initial value 0.
    """
    mapping = np.zeros(len(word_characters), dtype=int)
    word_position = 0
    for token_position, token_character in enumerate(token_characters):
        if word_position >= len(word_characters):
            # Every word character has been mapped.
            break
        if token_character in ignore_token_characters:
            continue
        if token_character == word_characters[word_position]:
            mapping[word_position] = token_position
            word_position += 1
    return mapping.tolist()

In [8]:
def remap_spans(
    document: data.CorefDocument, tokenize_fn: Callable[[str], list[str]], 
    verbose=False) -> data.CorefDocument:
    """Apply tokenize function at the document level by concatenating all
    words in the documents together. Then adjust the indices of the mentions
    in the coreference clusters, and the named entity and constituencies
    dictionaries. Prefer this over remap_spans_at_word_level if you are using
    transformer-based tokenizers.

    Word indices are converted to token indices through three lookups: word
    index -> character offset in the concatenated words -> character offset
    in the concatenated tokens (found with naive_mapping, never matching the
    "Ġ" token character) -> index of the token that owns that character.

    Args:
        document: Document whose spans are expressed in word indices.
        tokenize_fn: Function that splits the space-joined document text into
            tokens.
        verbose: If True print the words, the tokens and the lookup tables.

    Returns:
        A new document holding the tokens as sentences, one speaker per token
        (the first speaker of the original sentence), the remapped clusters,
        named entities and constituents, and the original doc_key and doc_id.

    Raises:
        AssertionError: If the word character to token character mapping is
            not non-decreasing, which happens when some word character could
            not be matched. The check runs last, after the verbose printout.
    """
    new_document = data.CorefDocument()
    words = [word for sentence in document.sentences for word in sentence]
    text = " ".join(words)
    tokens = tokenize_fn(text)
    word_characters = "".join(words)
    token_characters = "".join(tokens)
    word_begin_to_word_character = np.zeros(len(words), dtype=int)
    word_end_to_word_character = np.zeros(len(words), dtype=int)
    token_character_to_token_index = np.zeros(len(token_characters),
                                            dtype=int)

    # Offsets of the first and last character of every word.
    c = 0
    for i, word in enumerate(words):
        word_begin_to_word_character[i] = c
        word_end_to_word_character[i] = c + len(word) - 1
        c += len(word)

    # Plain Python list (naive_mapping already converts to a list).
    word_character_to_token_character = naive_mapping(
        word_characters, token_characters, ignore_token_characters=["Ġ"])

    # Every token character points to the index of its token.
    c = 0
    for i, token in enumerate(tokens):
        token_character_to_token_index[c: c + len(token)] = i
        c += len(token)

    def map_begin(word_begin: int) -> int:
        """Token index of the first character of word word_begin."""
        # int() so that the new mentions hold Python ints, not numpy integers.
        return int(token_character_to_token_index[
                word_character_to_token_character[
                    word_begin_to_word_character[word_begin]]])

    def map_end(word_end: int) -> int:
        """Token index of the last character of word word_end."""
        return int(token_character_to_token_index[
                word_character_to_token_character[
                    word_end_to_word_character[word_end]]])

    def remap_mention(mention: data.Mention) -> data.Mention:
        """Translate a word-indexed mention into a token-indexed mention."""
        return data.Mention(map_begin(mention.begin), map_end(mention.end))

    for cluster in document.clusters:
        new_cluster: set[data.Mention] = {
            remap_mention(mention) for mention in cluster}
        new_document.clusters.append(new_cluster)

    for mention, ner_tag in document.named_entities.items():
        new_document.named_entities[remap_mention(mention)] = ner_tag

    for mention, constituency_tag in document.constituents.items():
        new_document.constituents[remap_mention(mention)] = constituency_tag

    # Cut the token sequence into sentences at the token that holds the last
    # character of each original sentence.
    new_sentences = []
    new_speakers = []
    i, j = 0, 0
    for sentence, speakers in zip(document.sentences, document.speakers):
        n_words = len(sentence)
        end = map_end(i + n_words - 1)
        new_sentence = tokens[j: end + 1]
        i += n_words
        j = end + 1
        new_sentences.append(new_sentence)
        new_speakers.append([speakers[0] for _ in range(len(new_sentence))])
    new_document.sentences = new_sentences
    new_document.speakers = new_speakers

    new_document.doc_key = document.doc_key
    new_document.doc_id = document.doc_id

    if verbose:
        print(f"Number of words = {len(words)}")
        print(words)
        print()
        print(f"Number of tokens = {len(tokens)}")
        print(tokens)
        print()
        print(f"Word characters: {word_characters}")
        print()
        print(f"Token characters: {token_characters}")
        print()
        print("Word begin to word character:")
        print(word_begin_to_word_character)
        print()
        print("Word end to word character:")
        print(word_end_to_word_character)
        print()
        print("Word character to token character:")
        # Already a list: calling .tolist() on it raised AttributeError.
        print(word_character_to_token_character)
        print()
        print("Token character to token index")
        print(token_character_to_token_index.tolist())

    # Unmatched word characters are left at 0 by naive_mapping, so a drop in
    # the mapping reveals a failed alignment.
    # NOTE(review): a mapping made only of zeros passes this check - confirm
    # that case cannot occur for real documents.
    for i in range(len(word_character_to_token_character) - 1):
        assert word_character_to_token_character[i] <= (
                word_character_to_token_character[i + 1])

    return new_document

In [9]:
# Directory with the gold CoNLL-2012 jsonlines files. Kept in one place so
# that the notebook can be pointed at another checkout by editing one line.
CONLL_GOLD_DIR = "/home/sbaruah_usc_edu/mica_text_coref/data/conll-2012/gold"

# Load the three English splits.
test_corpus = data.CorefCorpus(f"{CONLL_GOLD_DIR}/test.english.jsonlines")
dev_corpus = data.CorefCorpus(f"{CONLL_GOLD_DIR}/dev.english.jsonlines")
train_corpus = data.CorefCorpus(f"{CONLL_GOLD_DIR}/train.english.jsonlines")

# Pass every split through data_util.remove_overlaps.
seq_test_corpus = data_util.remove_overlaps(test_corpus)
seq_dev_corpus = data_util.remove_overlaps(dev_corpus)
seq_train_corpus = data_util.remove_overlaps(train_corpus)

# Combine the splits in the order test, dev, train.
seq_corpus = seq_test_corpus + seq_dev_corpus + seq_train_corpus

Using '*' instead of '·'
Using '*' instead of '·'
Using '.' instead of '・'
Using '.' instead of '・'
Using '(www.mcoa.cn)' instead of '（www.mcoa.cn）'
Using '*' instead of '·'
Using '*' instead of '·'
Using '-LRB-c-RRB-?Io]o?' instead of '-LRB-c-RRB-?Ìö]o?'
Using 'UDI'' instead of 'ÛDÌ’'
Using 'Io]' instead of 'Ìö]'
Using 'Io...-LRB-c-RRB-x.goIo]' instead of 'Ìò...-LRB-c-RRB-x˙goÌö]'
Using 'iAnyway' instead of 'ِAnyway'
Using 'vis-a-vis' instead of 'vis-à-vis'
Using 'Y=' instead of '￥'
Using ':' instead of '：'
Using 'enosnail' instead of 'eのsnail'
Using '*Lingtai' instead of '＊Lingtai'
Using 'Shanyin*' instead of 'Shanyin＊'
Using '#' instead of '□'
Using '.' instead of '・'
Using '.' instead of '・'
Using '.' instead of '・'
Using '.' instead of '・'
Using '*' instead of '·'
Using '*' instead of '·'
Using '-' instead of '→'
Using '-' instead of '→'
Using '-' instead of '→'
Using '-' instead of '→'
Using '.' instead of '・'
Using '.' instead of '・'
Using '[(' instead of '【'
Using ')]' instead 

100%|██████████| 348/348 [00:01<00:00, 185.29it/s]
100%|██████████| 343/343 [00:01<00:00, 250.13it/s]
100%|██████████| 2802/2802 [00:12<00:00, 231.04it/s]


In [13]:
# Re-tokenize every document with the Longformer tokenizer and remap its spans
# to token indices. Documents for which remap_spans raises an AssertionError
# are left out of longformer_seq_corpus; their doc ids are collected in doc_ids
# and printed together with their doc keys.
doc_ids = []
longformer_seq_corpus = data.CorefCorpus()
for document in seq_corpus.documents:
    try:
        longformer_document = remap_spans(document, model.tokenizer.tokenize)
        longformer_seq_corpus.documents.append(longformer_document)
    except AssertionError as e:
        # Best effort: record the failing document and carry on.
        doc_ids.append(document.doc_id)
        print(document.doc_id, document.doc_key)

In [11]:
# Diagnose the documents collected in doc_ids: align the word characters with
# the token characters through lcs and print word characters that were left
# unaligned.
for doc_id in doc_ids:
    # NOTE(review): doc_id is used as a position in seq_corpus.documents -
    # confirm that doc ids coincide with positions in the combined corpus.
    document = seq_corpus.documents[doc_id]
    words = [word for sentence in document.sentences for word in sentence]
    text = " ".join(words)
    tokens = tokenizer.tokenize(text)
    word_characters = "".join(words)
    token_characters = "".join(tokens)
    word_character_to_token_character = lcs(word_characters, token_characters, 
                                            ignore_token_characters=["Ġ"])
    # lcs leaves unaligned characters at 0, so a character is reported when its
    # entry is 0 while the entry of the preceding character is positive. Only
    # the first character of a run of unaligned characters is reported.
    word_characters_not_matched = [word_characters[i] 
                                   for i in range(len(word_characters)) 
                                   if i > 0 
                                   and 
                                   word_character_to_token_character[i - 1] > 0 
                                   and 
                                   word_character_to_token_character[i] == 0
                                   ]
    print(f"doc_id = {doc_id}, doc_key = {document.doc_key}")
    print("word characters not matched:")
    print(word_characters_not_matched)
    print()

In [12]:
# Report every word that is empty once its whitespace characters are removed,
# i.e. words that are empty or made of whitespace only.
for document in seq_corpus.documents:
    for sentence in document.sentences:
        for word in sentence:
            whitespace_removed_word = re.sub(r"\s", "", word)
            if len(whitespace_removed_word) == 0:
                # Every fragment needs the f prefix, otherwise the braces are
                # printed literally. !r makes the whitespace word visible.
                print(f"doc_id = {document.doc_id}, "
                      f"doc_key = {document.doc_key}, "
                      f"whitespace word = {word!r}")

In [14]:
# For every document of the Longformer-tokenized corpus pick one representative
# mention per cluster, keeping document order and cluster order.
representative_mentions: list[list[data.Mention]] = [
    [representatives.representative_mention(cluster, document)
     for cluster in document.clusters]
    for document in longformer_seq_corpus.documents
]

In [17]:
# Build the tensor dataset: one example per (document, cluster) pair that
# create_tensors does not skip.
dataset = create_tensors(longformer_seq_corpus, representative_mentions, model)

In [19]:
# Sanity check: create_tensors packs six tensors into the dataset.
len(dataset.tensors)

6

In [21]:
# Shape of the token-id tensor: (number of examples, padded sequence length).
dataset.tensors[0].shape

torch.Size([43733, 4098])