# Coreference resolution using Fastcoref

```console
mamba install gradio
pip install fastcoref
```

## Resources
 
 - https://pypi.org/project/fastcoref/
 - https://neurosys.com/blog/intro-to-coreference-resolution-in-nlp
 - https://huggingface.co/spaces/pythiccoder/FastCoref/blob/main/app.py
 - [Stanford CS224n Coreference notes](https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1162/handouts/cs224n-lecture10-coreference.pdf)


## Demo Code from HF Demo

The actual demo is offline, but the code is simple enough to follow, and it uses the spaCy span visualizer (displaCy) as well, which makes it a good model for a coref test bench. It works fine in the notebook.

```python
from random import randint

# Pastel palette from https://www.color-hex.com/color-palette/1041293
pastel_colors = [
    # rainbow
    '#ffb3ba',
    '#ffdfba',
    '#ffffba',
    '#baffc9',
    '#bae1ff',
    # Spring valentine
    '#8478bf',
    '#bbabdb',
    '#ffe4e1',
    '#ffeb8e',
    '#ffc350',
]

def get_colors(num):
    """Return at least `num` colors: the pastel palette first, then random hex colors."""
    if num <= len(pastel_colors):
        return pastel_colors

    # Pad with random colors only once the palette runs out
    rand_colors = ['#%06X' % randint(0, 0xFFFFFF) for _ in range(num - len(pastel_colors))]
    return pastel_colors + rand_colors
```

```python
import spacy
from spacy import displacy

# Swap in FCoref for the faster, smaller model
from fastcoref import LingMessCoref

import gradio as gr

# LingMess uses a 2.36G weights file
model = LingMessCoref()
# Blank English pipeline: only the tokenizer, used to build a Doc for displaCy
nlp = spacy.blank("en")

default = "Lionel Messi has won a record seven Ballon d'Or awards. He signed for Paris Saint-Germain in August 2021. “I would like to thank my family” said the Argentinian footballer. Messi holds the records for most goals in La Liga. Paris Saint-Germain hopes he will do the same in Ligue 1."

def coref(text):
    # Pre-process text copied from PDF
    text = text.replace('\t', ' ').replace('\n', ' ')

    preds = model.predict(texts=[text])

    # as_strings=False returns character-offset spans (start, end)
    # as_strings=True returns the mention strings themselves
    clusters = preds[0].get_clusters(as_strings=False)

    # spaCy Doc used only to visualize the spans with displaCy
    doc = nlp(text)

    # One color per cluster; the span labels below must match these keys exactly
    color_hexs = get_colors(len(clusters))
    color_keys = {"cluster{}".format(i): color_hexs[i] for i in range(len(clusters))}

    spans = []
    for i, cluster in enumerate(clusters):
        for start, end in cluster:
            # "expand" snaps offsets to token boundaries; the default "strict" mode returns None on a mismatch
            span = doc.char_span(start, end, label="cluster{}".format(i), alignment_mode="expand")
            if span is not None:
                spans.append(span)
    doc.spans["sc"] = spans

    # jupyter=False makes render() return the HTML string even when running inside a notebook
    return displacy.render(doc, style="span", options={"colors": color_keys}, page=True, jupyter=False)

iface = gr.Interface(fn=coref,
                     inputs=gr.Textbox(label="Enter Text To Corefer with FastCoref", lines=2, value=default),
                     outputs="html")
iface.launch()
```
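`doc.char_span` returns `None` when fastcoref's character offsets do not fall on spaCy token boundaries (its default `alignment_mode` is `"strict"`), and the original demo appends that result to `doc.spans["sc"]` without a check. The sketch below is a minimal pure-Python illustration of snapping a character span outward to whole tokens, roughly the idea behind spaCy's `alignment_mode="expand"`. The helpers `expand_to_tokens` and `whitespace_token_offsets` are hypothetical, and the whitespace tokenizer is only a stand-in for spaCy's.

```python
from typing import List, Optional, Tuple

def expand_to_tokens(
    token_offsets: List[Tuple[int, int]], start: int, end: int
) -> Optional[Tuple[int, int]]:
    """Return (first_token, last_token_exclusive) covering chars [start, end).

    Any token that partially overlaps the character span is included,
    so the token range may cover more text than requested.
    Returns None if no token overlaps the span.
    """
    covered = [
        i for i, (tok_start, tok_end) in enumerate(token_offsets)
        if tok_start < end and tok_end > start  # token overlaps the span
    ]
    if not covered:
        return None
    return covered[0], covered[-1] + 1


def whitespace_token_offsets(text: str) -> List[Tuple[int, int]]:
    """Character offsets of whitespace-separated tokens (stand-in for spaCy's tokenizer)."""
    offsets, pos = [], 0
    for tok in text.split():
        start = text.index(tok, pos)
        offsets.append((start, start + len(tok)))
        pos = start + len(tok)
    return offsets


text = "Messi holds the records for most goals in La Liga."
tokens = whitespace_token_offsets(text)
# The span [1, 8) starts inside "Messi" and ends inside "holds"; expanding covers both tokens
first, last = expand_to_tokens(tokens, 1, 8)
print(text[tokens[first][0]:tokens[last - 1][1]])  # Messi holds
```

Expanding can make a highlighted mention slightly longer than the model's span, which is usually acceptable for a visual test bench.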

```
Some weights of the model checkpoint at biu-nlp/lingmess-coref were not used when initializing LingMessModel: ['longformer.embeddings.position_ids']
- This IS expected if you are initializing LingMessModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LingMessModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
02/23/2024 17:51:36 - INFO - 	 missing_keys: []
02/23/2024 17:51:36 - INFO - 	 unexpected_keys: []
02/23/2024 17:51:36 - INFO - 	 mismatched_keys: []
02/23/2024 17:51:36 - INFO - 	 error_msgs: []
02/23/2024 17:51:36 - INFO - 	 Model Parameters: 590.0M, Transformer: 434.6M, Coref

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
```

## Explore coref results



```python
preds = model.predict(texts=[default])
```


```python
# Explore the types
# predict() returns a list of fastcoref.modeling.CorefResult objects
print(f"Type of the preds is {type(preds[0])}")

# A cluster groups all mentions that refer to the same entity (not only named entities).
# In the raw `clusters` attribute each mention is a tuple (a, b) of indices into the
# model's tokenized input, not into the text; the end looks inclusive, since (15, 15)
# is a one-token mention. get_clusters(as_strings=False) maps them to character
# offsets, where text[a:b] is the mention.
print(f"Type of each cluster inside a pred is {type(preds[0].clusters[0])}")
for c in preds[0].clusters:
    print(c)
```

```
Type of the preds is <class 'fastcoref.modeling.CorefResult'>
Type of each cluster inside a pred is <class 'tuple'>
((1, 2), (15, 15), (29, 29), (34, 34), (39, 42), (44, 44), (61, 61))
((18, 22), (55, 59))
```
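Because the raw `clusters` index the model's tokens, they cannot slice `text` directly; the character offsets from `get_clusters(as_strings=False)` can. A minimal plain-Python sketch of working with those offsets, assuming they are end-exclusive `(start, end)` pairs as `get_clusters` treats them. The spans are copied from the `print_preds` output for the Barack Obama sample in the CS224n section below, and `cluster_strings` and `mention_to_cluster` are hypothetical helpers.

```python
from typing import Dict, List, Tuple

CharSpan = Tuple[int, int]  # (start, end) character offsets, end exclusive

def cluster_strings(text: str, clusters: List[List[CharSpan]]) -> List[List[str]]:
    """Slice each mention out of the text, mirroring get_clusters(as_strings=True)."""
    return [[text[start:end] for start, end in cluster] for cluster in clusters]

def mention_to_cluster(clusters: List[List[CharSpan]]) -> Dict[CharSpan, int]:
    """Map every mention span to the index of the cluster it belongs to."""
    return {span: i for i, cluster in enumerate(clusters) for span in cluster}

text = ("Barack Obama nominated Hillary Rodham Clinton as his secretary of state on Monday. "
        "He chose her because she had foreign affairs experience as a former First Lady.")
# Character spans as printed by get_clusters(as_strings=False) for this sample
clusters = [[(0, 12), (49, 52), (83, 85)], [(23, 45), (92, 95), (104, 107)]]

print(cluster_strings(text, clusters))
# [['Barack Obama', 'his', 'He'], ['Hillary Rodham Clinton', 'her', 'she']]

index = mention_to_cluster(clusters)
print(text[83:85], "->", cluster_strings(text, clusters)[index[(83, 85)]][0])
# He -> Barack Obama
```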


This class, `fastcoref.modeling.CorefResult`, has the following definition:

```python
class CorefResult:
    def __init__(self, text, clusters, char_map, reverse_char_map, coref_logit, text_idx):
        self.text = text
        self.clusters = clusters
        self.char_map = char_map
        self.reverse_char_map = reverse_char_map
        self.coref_logit = coref_logit
        self.text_idx = text_idx

    def get_clusters(self, as_strings=True):
        if not as_strings:
            return [[self.char_map[mention][1] for mention in cluster] for cluster in self.clusters]

        return [[self.text[self.char_map[mention][1][0]:self.char_map[mention][1][1]]
                 for mention in cluster if None not in self.char_map[mention]] for cluster in self.clusters]

    def get_logit(self, span_i, span_j):
        if span_i not in self.reverse_char_map:
            raise ValueError(f'span_i="{self.text[span_i[0]:span_i[1]]}" is not an entity in this model!')
        if span_j not in self.reverse_char_map:
            raise ValueError(f'span_i="{self.text[span_j[0]:span_j[1]]}" is not an entity in this model!')

        span_i_idx = self.reverse_char_map[span_i][0]   # 0 is to get the span index
        span_j_idx = self.reverse_char_map[span_j][0]

        if span_i_idx < span_j_idx:
            return self.coref_logit[span_j_idx, span_i_idx]

        return self.coref_logit[span_i_idx, span_j_idx]

    def __str__(self):
        if len(self.text) > 50:
            text_to_print = f'{self.text[:50]}...'
        else:
            text_to_print = self.text
        return f'CorefResult(text="{text_to_print}", clusters={self.get_clusters()})'

    def __repr__(self):
        return self.__str__()


```
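In `get_logit` above, the two spans are always looked up as `coref_logit[later, earlier]`, whichever order they are passed in, so the call is symmetric in `span_i` and `span_j`. A toy sketch of that lookup rule, with a hypothetical nested-list matrix standing in for the real `coref_logit` array; reading each row as "this mention scored against earlier candidate antecedents" is my interpretation of the code, not something the class documents.

```python
from typing import List

def pair_logit(coref_logit: List[List[float]], i: int, j: int) -> float:
    """Look up the pairwise score for span indices i and j the way get_logit does.

    Only entries with row >= column are read, i.e. [later_span, earlier_span].
    """
    if i < j:
        return coref_logit[j][i]
    return coref_logit[i][j]

# Hypothetical 3-span matrix; only the lower triangle is ever read
toy = [
    [0.0, 0.0, 0.0],
    [2.5, 0.0, 0.0],
    [-1.0, 4.0, 0.0],
]
print(pair_logit(toy, 0, 2), pair_logit(toy, 2, 0))  # -1.0 -1.0
```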

```python
"""Barack Obama nominated Hillary Rodham Clinton as his	
secretary	of	state	on	Monday.	He	chose	her	because	she	
had	foreign	affairs	experience	as	a	former	First	Lady.""".replace('\t', ' ').replace('\n',' ')
```

```
'Barack Obama nominated Hillary Rodham Clinton as his  secretary of state on Monday. He chose her because she  had foreign affairs experience as a former First Lady.'
```
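Replacing tabs and newlines one for one leaves double spaces (`his  secretary`, `she  had`) wherever a line ended in a tab. A minimal sketch that collapses every whitespace run into a single space instead; `normalize_whitespace` is a hypothetical helper, and its output offsets no longer line up with the raw PDF text, which is fine as long as the model and spaCy both see the normalized string.

```python
import re

def normalize_whitespace(text: str) -> str:
    """Collapse tabs, newlines and repeated spaces into single spaces and trim the ends."""
    return re.sub(r"\s+", " ", text).strip()

raw = "Barack Obama nominated Hillary Rodham Clinton as his\t\nsecretary\tof\tstate\ton\tMonday."
print(normalize_whitespace(raw))
# Barack Obama nominated Hillary Rodham Clinton as his secretary of state on Monday.
```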

## Explore samples from Stanford CS224n course

While looking at some of the samples from [Stanford's CS224n from way back](https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1162/handouts/cs224n-lecture10-coreference.pdf), I noticed that the slides mark more mentions than LingMess (and possibly the other models) returns. I need to figure out why later.

### Barack Hillary

**"Barack Obama nominated Hillary Rodham Clinton as his secretary of state on Monday. He chose her because she had foreign affairs experience as a former First Lady."**

| Context | Output | Deltas |
| :-- | :-- | :-- | 
| CS224n Expectation | ![](./img/coref-barack-hillary.png) | |
| LingMess Output | ![](./img/coref-barack-hillary-lingmess.png) | Missing _his secretary of state_, _First Lady_ |

### CFO and his pay

**"John Smith, CFO of Prime Corp. since 1986, saw his pay jump 20% to $1.3 million as the 57-year-old also became the financial services co.’s president"**

| Context | Output | Deltas |
| :-- | :-- | :-- | 
| CS224n Expectation | ![](./img/coref-ceo-pay-expected.png) | |
| LingMess Output | ![](./img/coref-ceo-pay-lingmess.png) | Missing _his pay_, _$1.3 million_, _the financial services co.’s president_ |
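The Deltas column can be computed rather than eyeballed by comparing mention strings. A minimal sketch: `predicted` is the `get_clusters()` output for this sample, printed by `print_preds` later in this section; `expected` is an assumed, hand-built set (the LingMess mentions plus the ones listed as missing in the table above); and `missing_mentions` is a hypothetical helper that compares mention strings only and ignores cluster membership.

```python
from typing import List, Set

def missing_mentions(expected: Set[str], predicted_clusters: List[List[str]]) -> List[str]:
    """Return the expected mentions that appear in no predicted cluster, sorted for stable output."""
    predicted = {mention.strip() for cluster in predicted_clusters for mention in cluster}
    return sorted(m for m in expected if m.strip() not in predicted)

# LingMess output for the CFO sample (get_clusters() with as_strings=True)
predicted = [
    ["John Smith, CFO of Prime Corp. since 1986,", "his", "the 57-year-old"],
    ["Prime Corp.", "the financial services co.’s"],
]
# Assumed expected mentions: the LingMess mentions plus those listed as missing in the table
expected = {
    "John Smith, CFO of Prime Corp. since 1986,", "his", "the 57-year-old",
    "Prime Corp.", "the financial services co.’s",
    "his pay", "$1.3 million", "the financial services co.’s president",
}
print(missing_mentions(expected, predicted))
# ['$1.3 million', 'his pay', 'the financial services co.’s president']
```

Comparing strings cannot tell apart two identical mentions at different positions; comparing the character spans from `get_clusters(as_strings=False)` would.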

```python
def print_preds(preds):
    # predict() returns one CorefResult per input text, so a one-element `texts` list
    # always gives exactly one pred. Several preds come from passing several texts,
    # not from alternative interpretations of one text; pairwise scores between
    # mentions are available through CorefResult.get_logit().
    for i, pred in enumerate(preds):
        print(f" Pred[{i}] ---")
        print(f" The clusters as strings are - {pred.get_clusters()}")
        print(f" The clusters as spans are - {pred.get_clusters(as_strings=False)}")

samples = [
    "Barack Obama nominated Hillary Rodham Clinton as his secretary of state on Monday. He chose her because she had foreign affairs experience as a former First Lady.",
    "John Smith, CFO of Prime Corp. since 1986, saw his pay jump 20% to $1.3 million as the 57-year-old also became the financial services co.’s president",
]

for sample in samples:
    preds = model.predict(texts=[sample])
    print_preds(preds)
```

```
 Pred[0] ---
 The clusters as strings are - [['Barack Obama', 'his', 'He'], ['Hillary Rodham Clinton', 'her', 'she']]
 The clusters as spans are - [[(0, 12), (49, 52), (83, 85)], [(23, 45), (92, 95), (104, 107)]]
 Pred[0] ---
 The clusters as strings are - [['John Smith, CFO of Prime Corp. since 1986,', 'his', 'the 57-year-old'], ['Prime Corp.', 'the financial services co.’s']]
 The clusters as spans are - [[(0, 42), (47, 50), (83, 98)], [(19, 30), (111, 139)]]
```
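One common use of these character spans is to rewrite a text so that each later mention is replaced by the first mention of its cluster. A minimal plain-Python sketch using the Barack Obama spans printed above; `resolve` is a hypothetical helper that assumes end-exclusive offsets, treats the earliest mention as the cluster's representative, and, when mentions overlap, keeps the first one applied in right-to-left order and skips the rest.

```python
from typing import List, Tuple

CharSpan = Tuple[int, int]  # (start, end) character offsets, end exclusive

def resolve(text: str, clusters: List[List[CharSpan]]) -> str:
    """Replace every non-first mention with the text of its cluster's first mention."""
    edits = []
    for cluster in clusters:
        ordered = sorted(cluster)
        head_start, head_end = ordered[0]
        head = text[head_start:head_end]  # taken from the original text before any splicing
        for start, end in ordered[1:]:
            edits.append((start, end, head))

    # Apply from right to left so earlier offsets stay valid after each splice
    edits.sort(reverse=True)
    boundary = len(text)  # start of the leftmost span replaced so far
    for start, end, head in edits:
        if end > boundary:
            continue  # overlaps a span that was already replaced; skip it
        text = text[:start] + head + text[end:]
        boundary = start
    return text

text = ("Barack Obama nominated Hillary Rodham Clinton as his secretary of state on Monday. "
        "He chose her because she had foreign affairs experience as a former First Lady.")
clusters = [[(0, 12), (49, 52), (83, 85)], [(23, 45), (92, 95), (104, 107)]]
print(resolve(text, clusters))
```

The result shows the usual limitation of naive substitution: the possessive _his_ becomes _Barack Obama_ rather than _Barack Obama's_.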