In [82]:
import re
import spacy
import numpy as np
import pandas as pd
from tqdm import tqdm
from copy import deepcopy
from spacy import displacy
from IPython.display import display
from sklearn.model_selection import ParameterGrid
from allennlp.predictors.predictor import Predictor
from evaluate_coreference import evaluate_coreference
from collections import defaultdict, Counter, OrderedDict
from evaluate_by_joining_elements import evaluate_coreference_by_joining_elements

In [2]:
# Load the pretrained AllenNLP coreference predictor (SpanBERT-large archive at the URL below).
predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz")

Did not use initialization regex that was passed: _context_layer._module.weight_hh.*
Did not use initialization regex that was passed: _context_layer._module.weight_ih.*


In [4]:
# Run the project's joining-elements coreference evaluation on the parsed "basterds" script
# against its mapped gold annotation CSV, using the AllenNLP predictor loaded above.
# NOTE(review): the meaning of the positional `-1` is not visible in this notebook — confirm
# against evaluate_coreference_by_joining_elements' signature.
basterds_result = evaluate_coreference_by_joining_elements("data/annotation/basterds.script_parsed.txt", "data/annotation/basterds.coref.mapped.csv", -1, use_speaker_sep=True, coreference_model=predictor)

loading spacy model


  2%|▏         | 10/591 [00:00<00:06, 91.79it/s]

spacy tokenization of screenplay elements


100%|██████████| 591/591 [00:06<00:00, 89.44it/s] 


finding global gold mention positions
	1008 gold mentions
	988 (98.01587301587301%) gold mentions found after parse
	980 (97.22222222222223%) gold mentions' spacy tokenization span found
finding gold clusters
23 gold clusters
finding sys clusters
	using 'says' after character names
	allennlp coreference resolution




40 sys clusters


MUC  : P = 0.7876 R = 0.6782 F1 = 0.7288
B3   : P = 0.5691 R = 0.2532 F1 = 0.3505
CEAFe: P = 0.2100 R = 0.3653 F1 = 0.2667
CoNLL 2012 score: 0.4487


In [5]:
# List the fields available in the evaluation result.
basterds_result.keys()

dict_keys(['evaluation', 'gold_clusters', 'sys_clusters', 'coref_dataframe', 'document', 'mention_tags'])

In [9]:
# Pull each field of the evaluation result into its own notebook variable.
basterds_evaluation = basterds_result["evaluation"]
# Mapping from gold entity name to that entity's cluster (its keys are entity names).
basterds_gold_entity_to_cluster = basterds_result["gold_clusters"]
# System clusters; each is iterated below as (start, end) token-index pairs.
basterds_sys_clusters = basterds_result["sys_clusters"]
basterds_coref_df = basterds_result["coref_dataframe"]
# Document text that was fed to the coreference model.
basterds_document = basterds_result["document"]
# One tag per token of the document (sliced by token index below).
basterds_mention_tags = basterds_result["mention_tags"]

In [10]:
# Parse the document text with spaCy's small English pipeline so that token offsets
# and named-entity labels are available for the analyses below.
spacy_nlp = spacy.load("en_core_web_sm")
spacy_basterds_document = spacy_nlp(basterds_document)

In [11]:
# Sanity check: later cells index the spaCy tokens and the mention tags with the same
# token positions, so the two lengths should match.
len(spacy_basterds_document), len(basterds_mention_tags)

(7828, 7828)

In [13]:
# Score the system clusters against the gold clusters directly; this should reproduce
# the metrics printed by the evaluation run above.
evaluate_coreference(basterds_gold_entity_to_cluster.values(), basterds_sys_clusters)

MUC  : P = 0.7876 R = 0.6782 F1 = 0.7288
B3   : P = 0.5691 R = 0.2532 F1 = 0.3505
CEAFe: P = 0.2100 R = 0.3653 F1 = 0.2667
CoNLL 2012 score: 0.4487


{'muc': {'R': 0.6781609195402298,
  'P': 0.787621359223301,
  'F1': 0.7288040426726559},
 'bcubed': {'R': 0.25324783617337143,
  'P': 0.5691203940018071,
  'F1': 0.3505206135514073},
 'ceafe': {'R': 0.36529256819057865,
  'P': 0.21004322670958273,
  'F1': 0.2667215577264543},
 'conll2012': {'R': 0.4322337746347267,
  'P': 0.5222616599782303,
  'F1': 0.4486820713168392}}

In [14]:
# Number of system (predicted) clusters.
len(basterds_sys_clusters)

40

In [15]:
# Number of gold entities (one cluster per entity).
len(basterds_gold_entity_to_cluster)

23

In [25]:
def _summarize_labels(labels):
    """Return the one shared label, or every label comma-joined when they differ."""
    if len(set(labels)) == 1:
        return labels[0]
    return ",".join(labels)


# Print every system mention, grouped by cluster, one line per mention:
#   <cluster number>. <mention text> <mention tag(s)> <spaCy NER label(s)>
for p, sys_cluster in enumerate(basterds_sys_clusters):
    for i, j in sys_cluster:
        # (i, j) are inclusive token indices; convert them to character offsets
        # so the mention text can be sliced out of the raw document string.
        start_char = spacy_basterds_document[i].idx
        end_char = spacy_basterds_document[j].idx + len(spacy_basterds_document[j])
        # Raw string literal: "\s" in a plain string is an invalid escape sequence
        # and triggers a DeprecationWarning/SyntaxWarning on recent Python versions.
        mention_text = re.sub(r"\s+", " ", basterds_document[start_char:end_char]).strip()

        mention_tags_text = _summarize_labels(basterds_mention_tags[i: j + 1])
        mention_ner_text = _summarize_labels(
            [token.ent_type_ for token in spacy_basterds_document[i: j + 1]]
        )

        print(f"{p + 1:2d}. {mention_text:30s} {mention_tags_text:10s} {mention_ner_text}")
    # Blank line between clusters.
    print()

 1. ALDO                           C          ORG
 1. Aldo with his hands up         D          
 1. ALDO'S                         C          ORG,
 1. ALDO'S                         C          ORG,
 1. ALDO'S                         C          ORG
 1. I                              D          
 1. ALDO'S                         C          GPE,
 1. you                            D          
 1. ALDO                           C          ORG
 1. LT.ALDO                        D          ORG
 1. his                            D          
 1. LT.ALDO                        C          ORG
 1. my                             D          
 1. I                              D          
 1. WILLI                          C          PERSON
 1. ALDO                           C          ORG
 1. ALDO                           C          ORG
 1. Willi                          N          ORG
 1. ALDO                           C          ORG
 1. You                            D          
 1. ALDO'S     

In [26]:
# Run the AllenNLP predictor directly on the whole document text to inspect its raw output.
basterds_coref_result = predictor.predict(document=basterds_document)

In [27]:
# Fields present in the raw predictor output.
basterds_coref_result.keys()

dict_keys(['top_spans', 'antecedent_indices', 'predicted_antecedents', 'document', 'clusters'])

In [28]:
# Number of clusters in the raw predictor output (compare with len(basterds_sys_clusters)).
len(basterds_coref_result["clusters"])

171

In [29]:
# Gold entity count, for comparison with the raw predictor cluster count.
len(basterds_gold_entity_to_cluster)

23

In [30]:
# Names of the gold entities.
basterds_gold_entity_to_cluster.keys()

dict_keys(['BRIDGET VON HAMMERSMARK', 'Barmaid', 'Donowitz', 'Eric', 'FEMALE SGT #2/BEETHOVEN', 'GERMAN PRIVATE #3/MATA HARI', 'GERMAN PRIVATE #4/EDGAR WALLACE', 'GERMAN PRIVATE #5/WINNETOU', 'Goebbels', 'Hirschberg', 'LT. ALDO RAINE', 'LT. HICOX', 'Leni', 'MAJOR HELLSTROM', 'MASTER SGT #1/POLA NEGRI', 'MAXIMILIAN', 'NAZI PRIVATE', 'READER', 'STIGLITZ', "Tavern's Proprietor", 'The Fubrer', 'Utivich', 'WILHELM WICKI'])

In [31]:
# Keep only the gold entities whose cluster holds more than one mention.
basterds_multiton_gold_entity_to_cluster = {
    entity_name: mentions
    for entity_name, mentions in basterds_gold_entity_to_cluster.items()
    if len(mentions) > 1
}

In [32]:
# Number of gold entities with more than one mention.
len(basterds_multiton_gold_entity_to_cluster)

23

In [33]:
# Keep a system cluster when at least one token of at least one of its mentions
# carries spaCy's PERSON entity label. Mentions are inclusive (start, end) token spans.
basterds_sys_person_clusters = [
    candidate_cluster
    for candidate_cluster in basterds_sys_clusters
    if any(
        spacy_basterds_document[position].ent_type_ == "PERSON"
        for first, last in candidate_cluster
        for position in range(first, last + 1)
    )
]

In [34]:
# Number of system clusters that survive the PERSON filter.
len(basterds_sys_person_clusters)

25

In [35]:
# Score only the PERSON-filtered system clusters against all gold clusters.
evaluate_coreference(basterds_gold_entity_to_cluster.values(), basterds_sys_person_clusters)

MUC  : P = 0.8685 R = 0.6489 F1 = 0.7428
B3   : P = 0.6184 R = 0.2489 F1 = 0.3549
CEAFe: P = 0.3065 R = 0.3332 F1 = 0.3193
CoNLL 2012 score: 0.4724


{'muc': {'R': 0.6489028213166145,
  'P': 0.8685314685314686,
  'F1': 0.742822966507177},
 'bcubed': {'R': 0.24887330294662618,
  'P': 0.6184243422803197,
  'F1': 0.35491692969055527},
 'ceafe': {'R': 0.33320147088209007,
  'P': 0.3065453532115228,
  'F1': 0.319318076262003},
 'conll2012': {'R': 0.4103258650484436,
  'P': 0.5978337213411037,
  'F1': 0.4723526574865784}}

In [59]:
# Toy parameter grid used below to try out sklearn's ParameterGrid.
configs = {"x": [1,2,3], "b": [3,4], "d": [1]}

In [62]:
def add(c, d, x, b):
    """Return c + d + x + b, accumulating left to right."""
    running_total = c + d
    running_total = running_total + x
    running_total = running_total + b
    return running_total

In [63]:
# Expand the toy grid: show each combination's values, then pass the combination
# to add() as keyword arguments (c is fixed at 0).
for grid_point in ParameterGrid(configs):
    combination_values = [*grid_point.values()]
    print(combination_values)
    combination_sum = add(0, **grid_point)
    print(combination_sum)

[3, 1, 1]
5
[3, 1, 2]
6
[3, 1, 3]
7
[4, 1, 1]
6
[4, 1, 2]
7
[4, 1, 3]
8


In [65]:
# Load the saved coreference evaluation results; rows are grouped by the "script" column below.
evaluation_df = pd.read_csv("results/coreference_evaluation.joining.csv", index_col=None)

In [67]:
# Columns available in the evaluation table (setting flags followed by metric scores).
evaluation_df.columns

Index(['script', 'use_speaker_sep', 'keep_only_speaker_gold_clusters',
       'remove_singleton_gold_clusters', 'keep_person_sys_clusters',
       'keep_speaker_sys_clusters', 'muc_R', 'muc_P', 'muc_F1', 'bcubed_R',
       'bcubed_P', 'bcubed_F1', 'ceafe_R', 'ceafe_P', 'ceafe_F1',
       'conll2012_R', 'conll2012_P', 'conll2012_F1', 'nec_F1',
       'nec_per_chains_missed', 'nec_name_F1', 'nec_pronoun_F1',
       'nec_nominal_F1'],
      dtype='object')

## Keep all gold clusters

In [74]:
# Per script: rows where neither gold-cluster filter is enabled, ranked by
# CoNLL-2012 F1 (best first). Shows the first six columns plus the two headline scores.
for script_name, script_rows in evaluation_df.groupby("script"):
    print(script_name)
    no_gold_filter = ~(
        script_rows.keep_only_speaker_gold_clusters
        | script_rows.remove_singleton_gold_clusters
    )
    ranked_rows = script_rows[no_gold_filter].sort_values(by="conll2012_F1", ascending=False)
    shown_columns = [*ranked_rows.columns[:6], "conll2012_F1", "nec_F1"]
    display(ranked_rows[shown_columns])

basterds


Unnamed: 0,script,use_speaker_sep,keep_only_speaker_gold_clusters,remove_singleton_gold_clusters,keep_person_sys_clusters,keep_speaker_sys_clusters,conll2012_F1,nec_F1
9,basterds,True,False,False,True,False,0.473133,0.510325
13,basterds,True,False,False,True,True,0.472353,0.383182
5,basterds,True,False,False,False,True,0.448682,0.420086
8,basterds,False,False,False,True,False,0.441704,0.44563
12,basterds,False,False,False,True,True,0.421643,0.320749
1,basterds,True,False,False,False,False,0.396567,0.647229
4,basterds,False,False,False,False,True,0.394878,0.379082
0,basterds,False,False,False,False,False,0.359593,0.603963


bourne


Unnamed: 0,script,use_speaker_sep,keep_only_speaker_gold_clusters,remove_singleton_gold_clusters,keep_person_sys_clusters,keep_speaker_sys_clusters,conll2012_F1,nec_F1
37,bourne,True,False,False,False,True,0.670066,0.45827
45,bourne,True,False,False,True,True,0.576487,0.390369
36,bourne,False,False,False,False,True,0.576051,0.445663
41,bourne,True,False,False,True,False,0.541976,0.390369
44,bourne,False,False,False,True,True,0.523946,0.384439
33,bourne,True,False,False,False,False,0.517408,0.531441
40,bourne,False,False,False,True,False,0.490964,0.384439
32,bourne,False,False,False,False,False,0.451429,0.517092


shawshank


Unnamed: 0,script,use_speaker_sep,keep_only_speaker_gold_clusters,remove_singleton_gold_clusters,keep_person_sys_clusters,keep_speaker_sys_clusters,conll2012_F1,nec_F1
69,shawshank,True,False,False,False,True,0.654428,0.510292
73,shawshank,True,False,False,True,False,0.586169,0.419578
77,shawshank,True,False,False,True,True,0.581947,0.305292
68,shawshank,False,False,False,False,True,0.571401,0.391707
65,shawshank,True,False,False,False,False,0.517815,0.674578
64,shawshank,False,False,False,False,False,0.465475,0.564921
72,shawshank,False,False,False,True,False,0.429565,0.317053
76,shawshank,False,False,False,True,True,0.40614,0.202605


## Keep only speaker gold clusters

In [77]:
# Per script: rows with keep_only_speaker_gold_clusters enabled and
# remove_singleton_gold_clusters disabled, ranked by CoNLL-2012 F1 (best first).
# Shows the first six columns plus the two headline scores.
for script_name, script_rows in evaluation_df.groupby("script"):
    print(script_name)
    speaker_gold_only = (
        script_rows.keep_only_speaker_gold_clusters
        & ~script_rows.remove_singleton_gold_clusters
    )
    ranked_rows = script_rows[speaker_gold_only].sort_values(by="conll2012_F1", ascending=False)
    shown_columns = [*ranked_rows.columns[:6], "conll2012_F1", "nec_F1"]
    display(ranked_rows[shown_columns])

basterds


Unnamed: 0,script,use_speaker_sep,keep_only_speaker_gold_clusters,remove_singleton_gold_clusters,keep_person_sys_clusters,keep_speaker_sys_clusters,conll2012_F1,nec_F1
29,basterds,True,True,False,True,True,0.498769,0.542173
21,basterds,True,True,False,False,True,0.466281,0.59895
28,basterds,False,True,False,True,True,0.452392,0.451152
25,basterds,True,True,False,True,False,0.436209,0.542173
20,basterds,False,True,False,False,True,0.413258,0.540896
24,basterds,False,True,False,True,False,0.408732,0.451152
17,basterds,True,True,False,False,False,0.360241,0.59895
16,basterds,False,True,False,False,False,0.319984,0.540896


bourne


Unnamed: 0,script,use_speaker_sep,keep_only_speaker_gold_clusters,remove_singleton_gold_clusters,keep_person_sys_clusters,keep_speaker_sys_clusters,conll2012_F1,nec_F1
53,bourne,True,True,False,False,True,0.753552,0.785606
52,bourne,False,True,False,False,True,0.647253,0.763994
61,bourne,True,True,False,True,True,0.646275,0.669204
60,bourne,False,True,False,True,True,0.591379,0.659038
57,bourne,True,True,False,True,False,0.589287,0.669204
56,bourne,False,True,False,True,False,0.536885,0.659038
49,bourne,True,True,False,False,False,0.477235,0.785606
48,bourne,False,True,False,False,False,0.415551,0.763994


shawshank


Unnamed: 0,script,use_speaker_sep,keep_only_speaker_gold_clusters,remove_singleton_gold_clusters,keep_person_sys_clusters,keep_speaker_sys_clusters,conll2012_F1,nec_F1
85,shawshank,True,True,False,False,True,0.711293,0.723526
93,shawshank,True,True,False,True,True,0.652021,0.46968
84,shawshank,False,True,False,False,True,0.639093,0.602626
89,shawshank,True,True,False,True,False,0.586132,0.46968
81,shawshank,True,True,False,False,False,0.476966,0.723526
92,shawshank,False,True,False,True,True,0.456209,0.3117
80,shawshank,False,True,False,False,False,0.425092,0.641087
88,shawshank,False,True,False,True,False,0.420916,0.325686


## Remove singleton gold clusters

In [78]:
# Per script: rows with remove_singleton_gold_clusters enabled and
# keep_only_speaker_gold_clusters disabled, ranked by CoNLL-2012 F1 (best first).
# Shows the first six columns plus the two headline scores.
for script_name, script_rows in evaluation_df.groupby("script"):
    print(script_name)
    singletons_removed = (
        ~script_rows.keep_only_speaker_gold_clusters
        & script_rows.remove_singleton_gold_clusters
    )
    ranked_rows = script_rows[singletons_removed].sort_values(by="conll2012_F1", ascending=False)
    shown_columns = [*ranked_rows.columns[:6], "conll2012_F1", "nec_F1"]
    display(ranked_rows[shown_columns])

basterds


Unnamed: 0,script,use_speaker_sep,keep_only_speaker_gold_clusters,remove_singleton_gold_clusters,keep_person_sys_clusters,keep_speaker_sys_clusters,conll2012_F1,nec_F1
11,basterds,True,False,True,True,False,0.473133,0.510325
15,basterds,True,False,True,True,True,0.472353,0.383182
7,basterds,True,False,True,False,True,0.448682,0.420086
10,basterds,False,False,True,True,False,0.441704,0.44563
14,basterds,False,False,True,True,True,0.421643,0.320749
3,basterds,True,False,True,False,False,0.396567,0.647229
6,basterds,False,False,True,False,True,0.394878,0.379082
2,basterds,False,False,True,False,False,0.359593,0.603963


bourne


Unnamed: 0,script,use_speaker_sep,keep_only_speaker_gold_clusters,remove_singleton_gold_clusters,keep_person_sys_clusters,keep_speaker_sys_clusters,conll2012_F1,nec_F1
39,bourne,True,False,True,False,True,0.703688,0.611027
38,bourne,False,False,True,False,True,0.604069,0.594218
47,bourne,True,False,True,True,True,0.596911,0.520492
43,bourne,True,False,True,True,False,0.55422,0.520492
46,bourne,False,False,True,True,True,0.544755,0.512585
35,bourne,True,False,True,False,False,0.519526,0.708588
42,bourne,False,False,True,True,False,0.504075,0.512585
34,bourne,False,False,True,False,False,0.452243,0.689456


shawshank


Unnamed: 0,script,use_speaker_sep,keep_only_speaker_gold_clusters,remove_singleton_gold_clusters,keep_person_sys_clusters,keep_speaker_sys_clusters,conll2012_F1,nec_F1
71,shawshank,True,False,True,False,True,0.682106,0.566991
79,shawshank,True,False,True,True,True,0.60603,0.339213
75,shawshank,True,False,True,True,False,0.604243,0.466197
70,shawshank,False,False,True,False,True,0.591465,0.43523
67,shawshank,True,False,True,False,False,0.522452,0.749531
66,shawshank,False,False,True,False,False,0.468845,0.62769
74,shawshank,False,False,True,True,False,0.444488,0.352281
78,shawshank,False,False,True,True,True,0.421801,0.225117


## Visualize sys clusters

In [85]:
# Print every token together with its named-entity label.
# spaCy tokens expose the label as `ent_type_` (string) / `ent_type` (integer id);
# `Token` has no `ent` attribute, so `token.ent` would raise AttributeError.
# The loop variable is deliberately named `token`: the cells below inspect it.
for token in spacy_basterds_document:
    print(f"{token.text!r} {token.ent_type_}")

In [90]:
# Inspect `token`, which is whatever an earlier loop left bound ('.' in the recorded output).
token

.

In [91]:
# Integer form of the token's entity type (0 in the recorded output).
token.ent_type

0

In [92]:
# String form of the token's entity type (empty string in the recorded output).
token.ent_type_

''

In [95]:
# Integer form of the token's entity id (0 in the recorded output).
token.ent_id

0

In [96]:
# String form of the token's entity id (empty string in the recorded output).
token.ent_id_

''

In [93]:
# Work on a copy so that the relabelling below is applied to a separate Doc object
# (presumably leaving the NER labels on spacy_basterds_document untouched).
coref_spacy_basterds_document = deepcopy(spacy_basterds_document)

# Clear the entity type of every token before writing cluster labels.
for token in coref_spacy_basterds_document:
    token.ent_type = 0

# Tag every token of each system mention with its cluster label: E1, E2, ...
# Only the string attribute is written: assigning `ent_type_` also sets the integer
# `ent_type`, so the previous `token.ent_type = i + 1` was a dead store.
for cluster_number, sys_cluster in enumerate(basterds_sys_clusters, start=1):
    cluster_label = f"E{cluster_number}"
    for first, last in sys_cluster:  # inclusive token span of one mention
        for position in range(first, last + 1):
            coref_spacy_basterds_document[position].ent_type_ = cluster_label

# NOTE(review): displaCy's "ent" style reads `doc.ents`; confirm that overriding the
# per-token entity type alone (without touching `ent_iob`) yields the intended spans.

In [94]:
# Render the relabelled copy with displaCy's entity view; the intent is to highlight
# each system cluster's tokens with its E<n> label.
displacy.render(coref_spacy_basterds_document, style="ent")

In [97]:
# `IPython.display` is the public import location; importing `display` from
# `IPython.core.display` is deprecated in recent IPython releases.
from IPython.display import HTML, display

In [98]:
# Inject the highlight stylesheet and script into the notebook output, in that order
# (presumably required by the gpr_pub visualization used below).
# `with` closes each file once it has been read instead of leaking the handle.
for asset_path in ("gpr_pub/visualization/highlight.css", "gpr_pub/visualization/highlight.js"):
    with open(asset_path) as asset_file:
        display(HTML(asset_file.read()))

In [99]:
from gpr_pub import visualization

In [104]:
import importlib

In [113]:
# Re-import the visualization helper after editing it on disk.
# NOTE(review): the recorded output is "TypeError: reload() argument must be a module",
# i.e. `visualization` was not bound to a module object when this ran. Confirm what
# `from gpr_pub import visualization` actually provides before relying on reload.
importlib.reload(visualization)

TypeError: reload() argument must be a module

In [112]:
# Render the raw predictor output with the gpr_pub visualization helper.
# NOTE(review): the recorded run failed with "NameError: name 'math' is not defined".
# This cell does not reference `math`, so the missing import is presumably inside
# gpr_pub's visualization code — confirm and fix it there.
visualization.render(basterds_coref_result, allen=True, jupyter=True)

NameError: name 'math' is not defined