# [Coreference Resolution through a seq2seq Transition-Based System](https://arxiv.org/abs/2211.12142)



```
@misc{bohnet2022coreference,
      title={Coreference Resolution through a seq2seq Transition-Based System}, 
      author={Bernd Bohnet and Chris Alberti and Michael Collins},
      year={2022},
      eprint={2211.12142},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

```

Adapted from the notebook:
https://github.com/google-research/google-research/tree/master/coref_mt5#coreference-resolution-through-a-seq2seq-transition-based-system

Adapted from the GitHub repository:
https://github.com/ianporada/mt5_coref_pytorch

In [1]:
import re

import nltk
import pandas as pd
import torch
from datasets import Dataset
from nltk import sent_tokenize
from transformers import MT5Tokenizer, T5ForConditionalGeneration
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [2]:
from state import State
from util import (create_document, create_next_batch, extract_result_string,
                  predict_coreferences, read_jsonl, write_jsonl)
from typing import List
from collections import defaultdict

In [3]:
import spacy
# spaCy pipeline whose entity labels are read later in the notebook.
# "en_core_web_sm" and "en_core_web_md" were the alternatives tried here.
spacy_model_name = "en_core_web_lg"
nlp = spacy.load(spacy_model_name)

## Load Dataset

In [4]:
# Path is relative to the notebook's working directory.
dataset_path = "../data/aggre_fact_final.csv"

# The first CSV column is consumed as the DataFrame index ...
df = pd.read_csv(filepath_or_buffer=dataset_path, index_col=0)

# ... and that index is then dropped rather than stored as a dataset column.
dataset_final = Dataset.from_pandas(df=df, preserve_index=False)

  if _pandas_api.is_sparse(col):


## Load Model and Tokenizer

In [5]:
# NLTK tokenizer used to split each sentence into word / punctuation tokens.
tokenizer_nltk = nltk.WordPunctTokenizer()

# The tokenizer and the model weights are loaded from the same checkpoint.
model_ckpt = "mt5-coref-pytorch/link-append-xxl"
tokenizer = MT5Tokenizer.from_pretrained(model_ckpt, legacy=False)

# Load the weights as float16, then move the model to the selected device.
model = T5ForConditionalGeneration.from_pretrained(
    model_ckpt,
    torch_dtype=torch.float16,
)
model = model.to(device)

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


In [6]:
# xpand_only = False

# Lower-case pronouns that are replaced outright by the resolved name.
pronouns = [
    "i", "he", "she", "you",
    "me", "him",
    "myself", "yourself", "himself", "herself", "yourselves",
]

# Lower-case possessive forms; these are replaced by the resolved name plus "'s".
special_pronouns = [
    "my", "mine",
    "her", "hers",
    "his",
    "your", "yours",
]

## Extract Coreferences

In [7]:
# Adapted from the github:
# https://github.com/ianporada/mt5_coref_pytorch

# Take the first example and title-case it. Row-first indexing reads a single
# row, whereas `dataset_final['doc'][0]` reads the whole 'doc' column just to
# keep its first element.
doc = dataset_final[0]['doc'].title()

# Split into sentences, then tokenize each sentence into words / punctuation.
sentences_list = sent_tokenize(doc)

# Single-document input: one entry per sentence with a placeholder speaker.
inputs = [{
    'document_id': 'example_doc',
    'sentences': [
        {'speaker': '_', 'words': tokenizer_nltk.tokenize(sentence)}
        for sentence in sentences_list
    ],
}]

# One State per input document, keyed by document id. The loop variable is
# deliberately not called `doc`, so the document string defined earlier in
# this cell is not overwritten by an input dict.
states_dict = {
    input_doc['document_id']: State(create_document(input_doc), tokenizer)
    for input_doc in inputs
}

# Keep predicting until `create_next_batch` returns no states. The original
# `num_done` counter was never incremented, so that empty result was (and
# still is) the only exit from the loop.
while states_dict:
    states, batches = create_next_batch(states_dict)
    if not states:
        break

    predictions = predict_coreferences(tokenizer, model, batches, len(batches))
    results = extract_result_string(predictions)

    # Feed each state its own prediction; `batches` stays in the zip so the
    # pairing is truncated exactly as before.
    for state, result, _batch in zip(states, results, batches):
        state.extend(result)

# Adapted from the notebook:
# https://github.com/google-research/google-research/tree/master/coref_mt5#coreference-resolution-through-a-seq2seq-transition-based-system

# Possessive marker in any of its tokenized spellings ("'s", "' s", " ' s").
# It has to be case-insensitive: mention heads are passed through str.title()
# before NER, and str.title() turns "'s" into "'S", so a lower-case-only
# match never fires.
POSSESSIVE_RE = re.compile(r"\s*'\s*s\b", flags=re.IGNORECASE)

# Lower-cased tokens that never get a resolved name inserted in front of them.
HONORIFICS = ("ms", "mr")

for doc_name, s in states_dict.items():
    all_pred_clusters = list(s.cluster_name_to_cluster.values())

    # Concatenate the per-sentence tokens and their token maps.
    text, text_map = [], []
    for k, snt in s.input_document['sentences'].items():
        text += snt
        text_map += s.input_document['token_maps'][k]

    # custom
    # Keep only clusters with at least one mention that spaCy tags as PERSON,
    # and remember the first such entity text as the name for that cluster.
    words_dict = {}
    pred_clusters = []
    for pred_cluster in all_pred_clusters:
        person_flag = False
        for st, en in pred_cluster:
            head = " ".join(text[st:en+1]).title()
            head_nlp = nlp(head)
            # Mentions containing three or more entities are skipped.
            if len(head_nlp.ents) >= 3:
                continue
            for ent in head_nlp.ents:
                if ent.label_ == "PERSON":
                    person_entity_index = s.mention_index_to_cluster_name[str((st, en))]
                    if person_entity_index not in words_dict:
                        # Store the bare name without any possessive marker.
                        words_dict[person_entity_index] = POSSESSIVE_RE.sub('', ent.text)
                    person_flag = True
                    break
        if person_flag:
            pred_clusters.append(pred_cluster)

    # For every entry of text_map, collect the mentions that start there
    # (cluster name, span length) and one ']' per mention that ends there.
    cluster_annotations_start = []
    cluster_annotations_end = []
    for tid in text_map:
        starts_here, ends_here = [], []
        for ci in pred_clusters:
            for m in ci:
                if tid == m[0]:
                    name = s.mention_index_to_cluster_name[str(m)]
                    starts_here.append((name, m[1] - m[0]))
                if tid == m[1]:
                    ends_here.append(']')
        cluster_annotations_start.append(starts_here)
        cluster_annotations_end.append(ends_here)

    all_text = []
    resolved_text = []

    for tok, start, end in zip(text, cluster_annotations_start, cluster_annotations_end):
        lower_tok = tok.lower()

        # --- resolved text -------------------------------------------------
        # Only the first mention opening at this token is considered.
        # `.get` yields None when no person name was recorded for the cluster;
        # the token is then emitted unchanged, with no blanket `except`.
        person_name = words_dict.get(start[0][0]) if start else None
        is_resolved = False
        if person_name is not None:
            if lower_tok in pronouns:
                resolved_text.append(person_name)
                is_resolved = True
            elif lower_tok in special_pronouns:
                resolved_text.append(person_name + "'s")
                is_resolved = True
            elif lower_tok not in HONORIFICS and not any(
                ent.label_ in ("PERSON", "ORG") for ent in nlp(tok).ents
            ):
                # Any other mention start gets "<name>," in front of it and the
                # token itself is still emitted below. Honorifics are compared
                # lower-cased, like the pronoun checks above.
                resolved_text.append(person_name + ',')
        if not is_resolved:
            resolved_text.append(lower_tok)

        # --- bracketed text ------------------------------------------------
        # Open the longest span first, then the token, then any closing brackets.
        for cluster_name, _span_len in sorted(start, key=lambda ann: ann[1], reverse=True):
            all_text.append('[' + str(cluster_name))
        all_text.append(lower_tok)
        if end:
            all_text.append(''.join(end))

    print()
    print(' '.join(all_text))
    print()
    print(' '.join(resolved_text))


france ' s dubuisson carded a 67 to tie with overnight leader van zyl of south africa on 16 under par . [1 mcilroy ] carded a third straight five under - par 67 to move to 15 under par with thailand ' s kiradech aphibarnrat . [1 the world number three ' s ] round included an eagle on the 12th as [1 he ] bids to win [1 his ] first title since may . " the 67s [1 i ] ' ve shot this week have all been a little different and [1 i ] feel like [1 i ] ' ve played within [1 myself ] for all of them , " said [1 four - time major winner mcilroy of northern ireland ] . " [1 i ] feel there ' s a low round out there for [1 me ] and hopefully it ' s tomorrow ." [1 mcilroy ] was level par for the day after 10 holes , dropping [1 his ] first shots of the week by three - putting the third and 10th , the latter mistake prompting [1 the 26 - year - old ] to throw [1 his ] putter at [1 his ] bag . but [1 he ] hit back with a birdie on the par - five 11th and a towering four iron from 229 yards on the 13th