In [13]:
import re
import spacy
import pandas as pd
import numpy as np
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
from collections import Counter, defaultdict
from allennlp.predictors.predictor import Predictor
from evaluate_by_joining_elements import evaluate_coreference_by_joining_elements

In [3]:
# Load the AllenNLP coreference predictor from the public model archive
# (SpanBERT-large, per the archive file name).
predictor = Predictor.from_path(
    "https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz"
)

Did not use initialization regex that was passed: _context_layer._module.weight_ih.*
Did not use initialization regex that was passed: _context_layer._module.weight_hh.*


In [4]:
# Coreference evaluation for the "basterds" screenplay and its annotation file.
basterds_result = evaluate_coreference_by_joining_elements(
    "data/annotation/basterds.script_parsed.txt",
    "data/annotation/basterds.coref.mapped.csv",
    -1,
    use_speaker_sep=True,
    coreference_model=predictor,
)

loading spacy model


  1%|          | 6/591 [00:00<00:10, 58.33it/s]

spacy tokenization of screenplay elements


100%|██████████| 591/591 [00:04<00:00, 121.74it/s]


finding global gold mention positions
	1008 gold mentions
	988 (98.01587301587301%) gold mentions found after parse
	980 (97.22222222222223%) gold mentions' spacy tokenization span found
finding gold clusters
23 gold clusters
finding sys clusters
	using 'says' after character names
	allennlp coreference resolution




40 sys clusters


MUC  : P = 0.7876 R = 0.6782 F1 = 0.7288
B3   : P = 0.5691 R = 0.2532 F1 = 0.3505
CEAFe: P = 0.2100 R = 0.3653 F1 = 0.2667
CoNLL 2012 score: 0.4487


In [36]:
# Coreference evaluation for the "bourne" screenplay and its annotation file.
bourne_result = evaluate_coreference_by_joining_elements(
    "data/annotation/bourne.script_parsed.txt",
    "data/annotation/bourne.coref.mapped.csv",
    -1,
    use_speaker_sep=True,
    coreference_model=predictor,
)

loading spacy model


  2%|▏         | 13/649 [00:00<00:05, 122.82it/s]

spacy tokenization of screenplay elements


100%|██████████| 649/649 [00:05<00:00, 123.79it/s]


finding global gold mention positions
	911 gold mentions
	894 (98.13391877058177%) gold mentions found after parse
	887 (97.36553238199781%) gold mentions' spacy tokenization span found
finding gold clusters
38 gold clusters
finding sys clusters
	using 'says' after character names
	allennlp coreference resolution
18 sys clusters


MUC  : P = 0.9129 R = 0.8021 F1 = 0.8539
B3   : P = 0.8308 R = 0.6253 F1 = 0.7135
CEAFe: P = 0.6887 R = 0.3262 F1 = 0.4427
CoNLL 2012 score: 0.6701


In [37]:
# Coreference evaluation for the "shawshank" screenplay and its annotation file.
shawshank_result = evaluate_coreference_by_joining_elements(
    "data/annotation/shawshank.script_parsed.txt",
    "data/annotation/shawshank.coref.mapped.csv",
    -1,
    use_speaker_sep=True,
    coreference_model=predictor,
)

loading spacy model


  2%|▏         | 8/525 [00:00<00:07, 70.12it/s]

spacy tokenization of screenplay elements


100%|██████████| 525/525 [00:04<00:00, 113.82it/s]


finding global gold mention positions
	888 gold mentions
	881 (99.21171171171171%) gold mentions found after parse
	880 (99.09909909909909%) gold mentions' spacy tokenization span found
finding gold clusters
44 gold clusters
finding sys clusters
	using 'says' after character names
	allennlp coreference resolution
26 sys clusters


MUC  : P = 0.9143 R = 0.8170 F1 = 0.8629
B3   : P = 0.6874 R = 0.6628 F1 = 0.6749
CEAFe: P = 0.5728 R = 0.3384 F1 = 0.4255
CoNLL 2012 score: 0.6544


In [34]:
def nec_f1_score(gold: set, sys: set):
    """Return the F1 overlap between a gold mention set and a system mention set.

    The score is ``2 * |gold ∩ sys| / (|gold| + |sys|)``.

    When ``gold`` is empty the ratio is not used: the result is ``1`` if
    ``sys`` is empty as well and ``0`` otherwise (returned as an ``int``).
    """
    n_shared = len(gold.intersection(sys))
    n_gold = len(gold)
    n_sys = len(sys)

    # Empty gold set: a perfect score only if the system set is empty too.
    if not n_gold:
        return int(n_sys == 0)

    return 2 * n_shared / (n_gold + n_sys)

def _mention_text(spacy_document, document, start, end):
    """Return the whitespace-normalised text of ``document`` spanned by tokens ``start``..``end`` (inclusive)."""
    char_begin = spacy_document[start].idx
    char_end = spacy_document[end].idx + len(spacy_document[end])
    return re.sub(r"\s+", " ", document[char_begin: char_end]).strip()


def _root_token(spacy_text):
    """Return the first token that is its own head, or ``None`` if the parse has no tokens."""
    return next((token for token in spacy_text if token.head == token), None)


def evaluate_coreference_nec(gold_entity_to_mentions, sys_clusters, coref_df, document):
    """Score system coreference clusters against gold entities that have name mentions.

    Only gold entities with at least one name mention (a row that is neither
    ``PRONOUN`` nor ``NOMINAL``) are evaluated. A system cluster is a candidate
    for an entity only if one of its mention texts contains one of the entity's
    names as a whitespace-delimited substring; candidate pairs are scored with
    ``nec_f1_score`` and entities are assigned to clusters one-to-one with
    ``linear_sum_assignment`` so that the total score is maximal.

    Parameters
    ----------
    gold_entity_to_mentions : mapping
        Entity label -> collection of ``(start, end)`` mention spans. The
        spans index tokens of ``document`` as tokenised here by spaCy's
        ``en_core_web_sm`` (assumes the upstream tokenisation matches —
        TODO confirm against the caller).
    sys_clusters : sequence of sets
        System clusters, each a set of ``(start, end)`` token spans.
    coref_df : pandas.DataFrame
        Gold annotation with columns ``entityLabel``, ``mention``,
        ``mention_start``, ``mention_end``, ``PRONOUN`` and ``NOMINAL``.
        The frame passed in is not modified.
    document : str
        Full document text.

    Returns
    -------
    dict
        ``nec_f1``, ``n_unmatched``, ``nec_pronoun_f1``, ``nec_nominal_f1``,
        ``nec_name_f1`` and ``meta_info`` (evaluated entity labels plus the
        row/column indices of the assignment). All averages are taken over the
        number of evaluated entities; they are ``0`` when there are none.

    Progress messages and the final scores are printed to stdout.
    """
    pronouns = "I, me, my, mine, myself, We, us, our, ours, ourselves, you, your, yours, yourself, yourselves, he, him, his, himself, she, her, hers, herself, it, its, itself, they, them, their, theirs, themself, themselves".lower().split(", ")

    # Derive a new frame instead of assigning into the caller's dataframe.
    coref_df = coref_df.assign(PRONOUN=coref_df.PRONOUN | coref_df.mention.str.lower().isin(pronouns))
    document_pronouns = set(coref_df[coref_df.PRONOUN].mention.str.lower().unique()).union(pronouns)

    spacy_nlp = spacy.load("en_core_web_sm")
    spacy_document = spacy_nlp(document)

    named_entity_to_mentions = {}
    gold_entity_to_name_mentions = {}
    gold_entity_to_names = {}
    gold_entity_to_pronoun_mentions = {}
    gold_entity_to_nominal_mentions = {}
    sys_pronoun_clusters = []
    sys_nominal_clusters = []
    sys_name_clusters = []
    # Lower-cased mention texts per system cluster, computed once and reused
    # for every gold entity during name matching.
    sys_cluster_texts = []

    print("finding gold names, pronouns and nominals mentions")
    for entity, df in coref_df.groupby("entityLabel"):
        name_mentions = set()
        pronoun_mentions = set()
        nominal_mentions = set()

        for _, row in df.iterrows():
            mention = (row.mention_start, row.mention_end)
            if row.PRONOUN:
                pronoun_mentions.add(mention)
            elif row.NOMINAL:
                nominal_mentions.add(mention)
            else:
                name_mentions.add(mention)

        # Keep only the annotated mentions that are present in the gold cluster.
        found_mentions = gold_entity_to_mentions[entity]
        name_mentions.intersection_update(found_mentions)
        pronoun_mentions.intersection_update(found_mentions)
        nominal_mentions.intersection_update(found_mentions)

        # Entities without any name mention are not evaluated.
        if name_mentions:
            gold_entity_to_name_mentions[entity] = name_mentions
            gold_entity_to_pronoun_mentions[entity] = pronoun_mentions
            gold_entity_to_nominal_mentions[entity] = nominal_mentions
            named_entity_to_mentions[entity] = name_mentions.union(pronoun_mentions).union(nominal_mentions)

    entities = list(named_entity_to_mentions.keys())

    print("finding names of entities")
    for entity in entities:
        names = set()

        for i, j in gold_entity_to_name_mentions[entity]:
            text = _mention_text(spacy_document, document, i, j)
            if not text:
                # A whitespace-only span has no parse root and yields no usable name.
                continue
            spacy_text = spacy_nlp(text)
            head_token = _root_token(spacy_text)

            for noun_chunk in spacy_text.noun_chunks:
                contains_proper_noun = any(token.pos_ == "PROPN" for token in noun_chunk)
                if contains_proper_noun and not noun_chunk.text.islower() and head_token in noun_chunk:
                    names.add(noun_chunk.text.lower())

            names.add(text.lower())

        gold_entity_to_names[entity] = names

    print("finding sys names, pronouns and nominals mentions")
    for mentions in sys_clusters:
        pronoun_mentions = set()
        nominal_mentions = set()
        name_mentions = set()
        lowered_texts = []

        for i, j in mentions:
            text = _mention_text(spacy_document, document, i, j)
            lowered = text.lower()
            lowered_texts.append(lowered)

            if lowered in document_pronouns:
                # Pronouns are decided by lookup; no parse is needed.
                pronoun_mentions.add((i, j))
                continue

            head_token = _root_token(spacy_nlp(text))
            if head_token is not None and head_token.pos_ == "PROPN":
                name_mentions.add((i, j))
            else:
                nominal_mentions.add((i, j))

        sys_pronoun_clusters.append(pronoun_mentions)
        sys_nominal_clusters.append(nominal_mentions)
        sys_name_clusters.append(name_mentions)
        sys_cluster_texts.append(lowered_texts)

    nec_f1_mat = np.zeros((len(entities), len(sys_clusters)))

    print("calculating nec f1")
    for i, entity in enumerate(entities):
        # One compiled pattern per gold name: the name must appear delimited by
        # whitespace or the string boundaries inside a system mention text.
        name_patterns = [
            re.compile(r"(^|\s)" + re.escape(name) + r"(\s|$)")
            for name in gold_entity_to_names[entity]
        ]
        gold_mentions = named_entity_to_mentions[entity]

        for j, sys_mentions in enumerate(sys_clusters):
            contains_gold_name = any(
                pattern.search(text)
                for text in sys_cluster_texts[j]
                for pattern in name_patterns
            )
            if contains_gold_name:
                nec_f1_mat[i, j] = nec_f1_score(gold_mentions, sys_mentions)

    row_ind, col_ind = linear_sum_assignment(nec_f1_mat, maximize=True)
    matched_scores = nec_f1_mat[row_ind, col_ind]

    n_entities = len(entities)
    # Avoid dividing by zero when no gold entity has a name mention.
    denominator = max(n_entities, 1)

    nec_f1 = matched_scores.sum() / denominator
    # An entity is missed when it has no positively-scored cluster. Counting
    # from n_entities also covers entities left out of the assignment when
    # there are fewer system clusters than entities.
    n_unmatched = n_entities - int(np.count_nonzero(matched_scores))

    nec_pronoun_f1 = 0
    nec_nominal_f1 = 0
    nec_name_f1 = 0

    # NOTE(review): pairs assigned with a zero NEC score still contribute their
    # per-type overlap below (including 1 for empty-vs-empty sets). Kept as in
    # the original so reported numbers do not change — confirm this is intended.
    for r, c in zip(row_ind, col_ind):
        entity = entities[r]
        nec_pronoun_f1 += nec_f1_score(gold_entity_to_pronoun_mentions[entity], sys_pronoun_clusters[c])
        nec_nominal_f1 += nec_f1_score(gold_entity_to_nominal_mentions[entity], sys_nominal_clusters[c])
        nec_name_f1 += nec_f1_score(gold_entity_to_name_mentions[entity], sys_name_clusters[c])

    nec_pronoun_f1 /= denominator
    nec_nominal_f1 /= denominator
    nec_name_f1 /= denominator

    print(f"NEC F1 = {nec_f1:.4f}, chains missed = {n_unmatched} ({100*n_unmatched/denominator:.2f}%)")
    print(f"NEC F1 for pronouns = {nec_pronoun_f1:.4f}, nominals = {nec_nominal_f1:.4f}, names = {nec_name_f1:.4f}")

    nec_result = {"nec_f1": nec_f1, "n_unmatched": n_unmatched, "nec_pronoun_f1": nec_pronoun_f1, "nec_nominal_f1": nec_nominal_f1, "nec_name_f1": nec_name_f1}
    meta_info = {"gold_entities": entities, "gold_ind": row_ind.tolist(), "sys_ind": col_ind.tolist()}
    nec_result["meta_info"] = meta_info
    return nec_result

In [5]:
# Pull the pieces of the basterds evaluation result into notebook-level names
# for interactive inspection.
gold_entity_to_cluster, sys_clusters, coref_df, document = (
    basterds_result[key]
    for key in ("gold_clusters", "sys_clusters", "coref_dataframe", "document")
)

In [35]:
# NEC evaluation on the basterds coreference output.
basterds_nec_result = evaluate_coreference_nec(
    gold_entity_to_mentions=basterds_result["gold_clusters"],
    sys_clusters=basterds_result["sys_clusters"],
    coref_df=basterds_result["coref_dataframe"],
    document=basterds_result["document"],
)

finding gold names, pronouns and nominals mentions
finding names of entities
finding sys names, pronouns and nominals mentions
calculating nec f1
NEC F1 = 0.4201, chains missed = 6 (30.00%)
NEC F1 for pronouns = 0.5357, nominals = 0.3763, names = 0.3513


In [38]:
# NEC evaluation on the bourne coreference output.
bourne_nec_result = evaluate_coreference_nec(
    gold_entity_to_mentions=bourne_result["gold_clusters"],
    sys_clusters=bourne_result["sys_clusters"],
    coref_df=bourne_result["coref_dataframe"],
    document=bourne_result["document"],
)

finding gold names, pronouns and nominals mentions
finding names of entities
finding sys names, pronouns and nominals mentions
calculating nec f1
NEC F1 = 0.4583, chains missed = 5 (41.67%)
NEC F1 for pronouns = 0.3680, nominals = 0.7500, names = 0.4461


In [39]:
# NEC evaluation on the shawshank coreference output.
shawshank_nec_result = evaluate_coreference_nec(
    gold_entity_to_mentions=shawshank_result["gold_clusters"],
    sys_clusters=shawshank_result["sys_clusters"],
    coref_df=shawshank_result["coref_dataframe"],
    document=shawshank_result["document"],
)

finding gold names, pronouns and nominals mentions
finding names of entities
finding sys names, pronouns and nominals mentions
calculating nec f1
NEC F1 = 0.5103, chains missed = 7 (35.00%)
NEC F1 for pronouns = 0.5285, nominals = 0.5160, names = 0.5685
