In [33]:
import pandas as pd
import numpy as np
from collections import Counter, defaultdict
from evaluate_by_joining_elements import evaluate_coreference_by_joining_elements
import spacy
from tqdm import tqdm, trange

from allennlp.predictors.predictor import Predictor
import allennlp_models.tagging


In [40]:
shawshank_result = evaluate_coreference_by_joining_elements("../../data/annotation/acl21/shawshank.script_parsed.txt", "../../data/annotation/acl21/shawshank.coref.mapped.csv", use_speaker_sep=True, keep_speaker_sys_clusters=True, heuristic_speaker_resolution=True)

print("\n\n\n")

basterds_result = evaluate_coreference_by_joining_elements("../../data/annotation/acl21/basterds.script_parsed.txt", "../../data/annotation/acl21/basterds.coref.mapped.csv", use_speaker_sep=True, keep_speaker_sys_clusters=True, heuristic_speaker_resolution=True)

print("\n\n\n")

bourne_result = evaluate_coreference_by_joining_elements("../../data/annotation/acl21/bourne.script_parsed.txt", "../../data/annotation/acl21/bourne.coref.mapped.csv", use_speaker_sep=True, keep_speaker_sys_clusters=True, heuristic_speaker_resolution=True)

loading spacy model


  2%|▏         | 10/525 [00:00<00:05, 86.90it/s]

spacy tokenization of screenplay elements


100%|██████████| 525/525 [00:05<00:00, 104.07it/s]


finding global gold mention positions
	888 gold mentions
	881 (99.21%) gold mentions found after parse
	880 (99.10%) gold mentions' spacy tokenization span found
finding gold clusters
44 gold clusters
loading allennlp coreference model
finding sys clusters
	using 'says' after character names
	allennlp coreference resolution
	spacy ner on document
	keeping speaker sys clusters
	heuristic speaker clustering
20 sys clusters
MUC  : P = 0.9055 R = 0.8254 F1 = 0.8636
B3   : P = 0.6002 R = 0.7177 F1 = 0.6537
CEAFe: P = 0.6685 R = 0.3039 F1 = 0.4178
CoNLL 2012 score: 0.6450




loading spacy model


  2%|▏         | 14/591 [00:00<00:04, 133.59it/s]

spacy tokenization of screenplay elements


100%|██████████| 591/591 [00:05<00:00, 111.80it/s]


finding global gold mention positions
	1008 gold mentions
	988 (98.02%) gold mentions found after parse
	980 (97.22%) gold mentions' spacy tokenization span found
finding gold clusters
23 gold clusters
loading allennlp coreference model
finding sys clusters
	using 'says' after character names
	allennlp coreference resolution
	spacy ner on document
	keeping speaker sys clusters
	heuristic speaker clustering
43 sys clusters
MUC  : P = 0.7938 R = 0.7482 F1 = 0.7703
B3   : P = 0.4753 R = 0.4514 F1 = 0.4630
CEAFe: P = 0.2158 R = 0.4035 F1 = 0.2812
CoNLL 2012 score: 0.5049




loading spacy model


  2%|▏         | 13/649 [00:00<00:05, 125.16it/s]

spacy tokenization of screenplay elements


100%|██████████| 649/649 [00:05<00:00, 125.73it/s]


finding global gold mention positions
	911 gold mentions
	894 (98.13%) gold mentions found after parse
	887 (97.37%) gold mentions' spacy tokenization span found
finding gold clusters
38 gold clusters
loading allennlp coreference model
finding sys clusters
	using 'says' after character names
	allennlp coreference resolution
	spacy ner on document
	keeping speaker sys clusters
	heuristic speaker clustering
18 sys clusters
MUC  : P = 0.9027 R = 0.8198 F1 = 0.8593
B3   : P = 0.8065 R = 0.7005 F1 = 0.7498
CEAFe: P = 0.5426 R = 0.2570 F1 = 0.3488
CoNLL 2012 score: 0.6526


In [5]:
spacy_pipeline = spacy.load("en_core_web_sm")

In [7]:
predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/structured-prediction-srl-bert.2020.12.15.tar.gz")

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=570.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=231508.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=466062.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=440473133.0), HTML(value='')))




In [41]:
agent_dicts = []
patient_dicts = []
agent_patient_dicts = []

for name, result in [["shawshank", shawshank_result], ["basterds", basterds_result], ["bourne", bourne_result]]:

    lines = open(f"../../data/annotation/acl21/{name}.script_parsed.txt").read().strip().split("\n")
    tags, elements, spacy_docs, count_spacy_docs, mention_tags = [], [], [], [], []

    for line_index in trange(0, len(lines)):
        line = lines[line_index]
        tag, element = line[0], line[2:].strip()
        spacy_doc = spacy_pipeline(element)
        tags.append(tag)
        elements.append(element)
        spacy_docs.append(spacy_doc)
        count_spacy_docs.append(len(spacy_doc))
        mention_tags.extend([tag] * len(spacy_doc) + ["X"])
    mention_tags = mention_tags[:-1]

    srl_result = []

    for doc in tqdm(spacy_docs):
        doc_srl_result = []
        for sent in doc.sents:
            doc_srl_result.append(predictor.predict(sentence=sent.text))
        srl_result.append(doc_srl_result)

    coref_df = result["coref_dataframe"]

    agent_dict = defaultdict(lambda: defaultdict(int))
    patient_dict = defaultdict(lambda: defaultdict(int))
    agent_patient_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

    for i in range(len(srl_result)):
        element_srl = srl_result[i]
        sent_docs = list(spacy_docs[i].sents)
        for j in range(len(element_srl)):
            sent_srl = element_srl[j]
            offset = sum(count_spacy_docs[:i]) + i + sent_docs[j][0].i

            for srl in sent_srl["verbs"]:
                action = srl["verb"]
                tags = srl["tags"]

                arg0 = [-1, -1]
                arg1 = [-1, -1]
                k = 0
                while k < len(tags):
                    if tags[k] == "B-ARG0":
                        l = k + 1
                        while l < len(tags) and tags[l] == "I-ARG0":
                            l += 1
                        arg0 = [k, l - 1]
                        k = l
                    elif tags[k] == "B-ARG1":
                        l = k + 1
                        while l < len(tags) and tags[l] == "I-ARG1":
                            l += 1
                        arg1 = [k, l - 1]
                        k = l
                    else:
                        k += 1

                if arg0[0] != -1:
                    arg0 = [arg0[0] + offset, arg0[1] + offset]
                    arg0_entity = coref_df.loc[(coref_df.mention_start == arg0[0]) & (coref_df.mention_end == arg0[1]), "entityLabel"].values
                    if len(arg0_entity) == 1:
                        arg0_entity = arg0_entity[0]
                        agent_dict[arg0_entity][action] += 1

                if arg1[0] != -1:
                    arg1 = [arg1[0] + offset, arg1[1] + offset]
                    arg1_entity = coref_df.loc[(coref_df.mention_start == arg1[0]) & (coref_df.mention_end == arg1[1]), "entityLabel"].values
                    if len(arg1_entity) == 1:
                        arg1_entity = arg1_entity[0]
                        patient_dict[arg1_entity][action] += 1

                if isinstance(arg0_entity, str) and isinstance(arg1_entity, str):
                    agent_patient_dict[arg0_entity][arg1_entity][action] += 1
    
    agent_dicts.append(agent_dict)
    patient_dicts.append(patient_dict)
    agent_patient_dicts.append(agent_patient_dict)
    
    print("\n\n")

100%|██████████| 525/525 [00:04<00:00, 108.39it/s]
100%|██████████| 525/525 [02:17<00:00,  3.81it/s]
  2%|▏         | 14/591 [00:00<00:04, 138.23it/s]






100%|██████████| 591/591 [00:05<00:00, 106.23it/s]
100%|██████████| 591/591 [02:16<00:00,  4.32it/s]
  2%|▏         | 13/649 [00:00<00:05, 124.07it/s]






100%|██████████| 649/649 [00:05<00:00, 126.13it/s]
100%|██████████| 649/649 [02:23<00:00,  4.51it/s]







In [42]:
agent_patient_dicts[2]

defaultdict(<function __main__.<lambda>()>,
            {'Bourne': defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
                         {'Bourne': defaultdict(int,
                                      {'coming': 2,
                                       'skyline': 1,
                                       'is': 2,
                                       "'d": 1,
                                       'be': 2,
                                       'runs': 1,
                                       'rushes': 1,
                                       'Box': 2,
                                       'waiting': 1,
                                       "'ve": 1,
                                       'Take': 1,
                                       'struggles': 1,
                                       'moving': 1,
                                       'disappears': 2,
                                       'looks': 3,
                                       'forced': 1,


In [45]:
agent_patient_dicts[0]

defaultdict(<function __main__.<lambda>()>,
            {"ANDY'S WIFE'S LOVER": defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
                         {"ANDY'S WIFE": defaultdict(int,
                                      {'slams': 1,
                                       'enters': 1,
                                       'carries': 1})}),
             "ANDY'S WIFE": defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
                         {"ANDY'S WIFE": defaultdict(int,
                                      {'cries': 1,
                                       'shivering': 1,
                                       'was': 1,
                                       'stay': 1,
                                       'had': 1}),
                          'ANDY': defaultdict(int, {'stands': 1})}),
             'READER': defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
                         {'ANDY': defaultdict(int, {'is': 1, 'are': 1}),
         

In [46]:
results_df = pd.read_csv("../../results/acl21/coreference_evaluation.all.csv", index_col=None)

In [47]:
results_df.columns

Index(['script', 'keep_only_speaker_gold_clusters',
       'remove_singleton_gold_clusters', 'heuristic_pronoun_resolution',
       'use_speaker_sep', 'keep_person_sys_clusters',
       'keep_speaker_sys_clusters', 'heuristic_speaker_resolution',
       'min_speaker_sim', 'max_speaker_merges', 'muc_R', 'muc_P', 'muc_F1',
       'bcubed_R', 'bcubed_P', 'bcubed_F1', 'ceafe_R', 'ceafe_P', 'ceafe_F1',
       'conll2012_R', 'conll2012_P', 'conll2012_F1', 'nec_F1',
       'nec_per_chains_missed', 'nec_name_F1', 'nec_pronoun_F1',
       'nec_nominal_F1', 'mention_P', 'mention_R', 'mention_F1'],
      dtype='object')

In [48]:
results_df["min_speaker_sim"]

0       0.5
1       0.5
2       0.6
3       0.6
4       0.7
       ... 
1723    0.8
1724    0.9
1725    0.9
1726    1.0
1727    1.0
Name: min_speaker_sim, Length: 1728, dtype: float64

In [50]:
results_df[results_df.use_speaker_sep & ~results_df.keep_person_sys_clusters & results_df.keep_speaker_sys_clusters & results_df.heuristic_speaker_resolution].groupby(["min_speaker_sim","max_speaker_merges"]).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,keep_only_speaker_gold_clusters,remove_singleton_gold_clusters,heuristic_pronoun_resolution,use_speaker_sep,keep_person_sys_clusters,keep_speaker_sys_clusters,heuristic_speaker_resolution,muc_R,muc_P,muc_F1,...,conll2012_P,conll2012_F1,nec_F1,nec_per_chains_missed,nec_name_F1,nec_pronoun_F1,nec_nominal_F1,mention_P,mention_R,mention_F1
min_speaker_sim,max_speaker_merges,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
0.5,0,False,False,False,True,False,True,True,0.797076,0.86562,0.829873,...,0.651762,0.595051,0.418594,40.555556,0.416055,0.538954,0.375423,0.886598,0.808334,0.845362
0.5,1,False,False,False,True,False,True,True,0.797076,0.86562,0.829873,...,0.651762,0.595051,0.418594,40.555556,0.416055,0.538954,0.375423,0.886598,0.808334,0.845362
0.5,2,False,False,False,True,False,True,True,0.798818,0.864341,0.830223,...,0.644091,0.582553,0.404146,38.888889,0.403973,0.526561,0.345846,0.886598,0.808334,0.845362
0.5,3,False,False,False,True,False,True,True,0.798469,0.863582,0.829686,...,0.652644,0.586934,0.416359,35.555556,0.41464,0.522568,0.313526,0.886598,0.808334,0.845362
0.5,4,False,False,False,True,False,True,True,0.799166,0.863247,0.829908,...,0.641932,0.576192,0.404162,35.555556,0.403874,0.509894,0.329167,0.886598,0.808334,0.845362
0.5,5,False,False,False,True,False,True,True,0.799559,0.863291,0.830143,...,0.646237,0.577111,0.404162,35.555556,0.403874,0.509894,0.329167,0.886598,0.808334,0.845362
0.6,0,False,False,False,True,False,True,True,0.797424,0.868937,0.8316,...,0.658382,0.604041,0.466964,38.888889,0.45048,0.605705,0.458845,0.886598,0.808334,0.845362
0.6,1,False,False,False,True,False,True,True,0.797424,0.868937,0.8316,...,0.658382,0.604041,0.466964,38.888889,0.45048,0.605705,0.458845,0.886598,0.808334,0.845362
0.6,2,False,False,False,True,False,True,True,0.797424,0.867957,0.831146,...,0.656793,0.601958,0.466469,38.888889,0.449812,0.605705,0.48652,0.886598,0.808334,0.845362
0.6,3,False,False,False,True,False,True,True,0.797773,0.867342,0.831048,...,0.656777,0.600841,0.466497,38.888889,0.449713,0.605705,0.486433,0.886598,0.808334,0.845362


In [52]:
len(results_df)

1728

In [54]:
results_df.keep_only_speaker_gold_clusters.unique()

array([False])

In [55]:
results_df.remove_singleton_gold_clusters.unique()

array([False])

In [57]:
results_df.heuristic_pronoun_resolution.unique()

array([False])

In [77]:
pd.set_option("display.max_rows", 1000)
results_df.groupby(["min_speaker_sim", "max_speaker_merges", "use_speaker_sep", "keep_person_sys_clusters", "keep_speaker_sys_clusters", "heuristic_speaker_resolution"]).mean()[["conll2012_F1", "muc_F1", "bcubed_F1", "ceafe_F1"]].sort_values(by="conll2012_F1", ascending=False)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,conll2012_F1,muc_F1,bcubed_F1,ceafe_F1
min_speaker_sim,max_speaker_merges,use_speaker_sep,keep_person_sys_clusters,keep_speaker_sys_clusters,heuristic_speaker_resolution,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0.9,1,True,False,True,True,0.626821,0.832141,0.632443,0.41588
0.9,2,True,False,True,True,0.626821,0.832141,0.632443,0.41588
0.9,0,True,False,True,True,0.626821,0.832141,0.632443,0.41588
0.9,3,True,False,True,True,0.626821,0.832141,0.632443,0.41588
0.9,4,True,False,True,True,0.626821,0.832141,0.632443,0.41588
0.9,5,True,False,True,True,0.626821,0.832141,0.632443,0.41588
1.0,1,True,False,True,True,0.625046,0.832517,0.632622,0.409998
1.0,4,True,False,True,True,0.625046,0.832517,0.632622,0.409998
1.0,0,True,False,True,True,0.625046,0.832517,0.632622,0.409998
1.0,5,True,False,True,True,0.625046,0.832517,0.632622,0.409998
