In [1]:
import re
import spacy
import importlib
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from copy import deepcopy
from spacy import displacy
from gpr_pub import visualization
from IPython.core.display import display, HTML
from sklearn.model_selection import ParameterGrid
from allennlp.predictors.predictor import Predictor
from evaluate_coreference import evaluate_coreference
from collections import defaultdict, Counter, OrderedDict
from evaluate_by_joining_elements import evaluate_coreference_by_joining_elements

import textdistance

In [2]:
# Inject the stylesheet and script that presumably back the `highlight*` classes
# and handleHighlightMouseOver/Out handlers emitted by visualization.render.
# `with` closes each file deterministically; explicit UTF-8 avoids depending
# on the platform's locale encoding.
for asset_path in ("gpr_pub/visualization/highlight.css", "gpr_pub/visualization/highlight.js"):
    with open(asset_path, encoding="utf-8") as asset_file:
        display(HTML(asset_file.read()))

In [3]:
# Pretrained SpanBERT-large coreference model, loaded from the public AllenNLP archive.
COREF_MODEL_URL = "https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz"
predictor = Predictor.from_path(COREF_MODEL_URL)

Did not use initialization regex that was passed: _context_layer._module.weight_ih.*
Did not use initialization regex that was passed: _context_layer._module.weight_hh.*


In [4]:
# Run the coreference model over the parsed Shawshank screenplay and score it
# against the gold annotation (prints MUC / B3 / CEAFe / CoNLL scores).
shawshank_result = evaluate_coreference_by_joining_elements(
    "data/annotation/shawshank.script_parsed.txt",
    "data/annotation/shawshank.coref.mapped.csv",
    use_speaker_sep=True,
    keep_speaker_sys_clusters=False,
    coreference_model=predictor,
)

loading spacy model


  0%|          | 1/525 [00:00<01:04,  8.11it/s]

spacy tokenization of screenplay elements


100%|██████████| 525/525 [00:05<00:00, 88.05it/s] 


finding global gold mention positions
	888 gold mentions
	881 (99.21%) gold mentions found after parse
	880 (99.10%) gold mentions' spacy tokenization span found
finding gold clusters
44 gold clusters
finding sys clusters
	using 'says' after character names
	allennlp coreference resolution




	spacy ner on document
176 sys clusters
MUC  : P = 0.6614 R = 0.9067 F1 = 0.7649
B3   : P = 0.4681 R = 0.7298 F1 = 0.5704
CEAFe: P = 0.1364 R = 0.5454 F1 = 0.2182
CoNLL 2012 score: 0.5178


In [5]:
# Give the renderer our system clusters instead of the model's own: each
# cluster's (start, end) token spans become a list of [start, end] pairs.
# Note: this mutates shawshank_result["coref_result"] in place.
_clusters = [
    [[start, end] for start, end in sys_cluster]
    for sys_cluster in shawshank_result["sys_clusters"]
]
coref_result = shawshank_result["coref_result"]
coref_result["clusters"] = _clusters

In [6]:
html = visualization.render(coref_result, allen=True, jupyter=False)

In [16]:
html

'<div style="padding: 16px;"><span>THE </span><span>SHAWSHANK </span><span>REDEMPTION </span><span>\n </span><span>by </span><span>\n </span><span>Frank </span><span>Darabont </span><span>\n </span><span>Based </span><span>upon </span><span>the </span><span>story </span><span>\n </span><span key=14 class="highlight pink" depth=0 id=54 onmouseover="handleHighlightMouseOver(this)"                 onmouseout="handleHighlightMouseOut(this)" labelPosition="left">                <span class="highlight__label"><strong>54</strong></span>                <span class="highlight__content"><span>Rita </span> <span>Hayworth </span></span></span><span>and </span><span>Shawshank </span><span>Redemption </span><span>\n </span><span>by </span><span>Stephen </span><span>King </span><span>\n </span><span>1 </span><span>INT </span><span>-- </span><span key=26 class="highlight red" depth=0 id=7 onmouseover="handleHighlightMouseOver(this)"                 onmouseout="handleHighlightMouseOut(this)" labelPosit

In [15]:
HTML(html.replace("\n", "<br/>"))

In [11]:
# Persist the rendered markup so it can be opened outside the notebook.
# `with` flushes and closes the file deterministically; explicit UTF-8 keeps any
# non-ASCII text writable regardless of the platform's locale encoding.
with open("visualization.html", "w", encoding="utf-8") as html_file:
    html_file.write(html)

557222

In [14]:
# Re-load the saved visualization and render it inline; "\n" becomes <br/> so
# the screenplay's line breaks are visible in HTML.
with open("visualization.html", encoding="utf-8") as html_file:
    saved_html = html_file.read()
HTML(saved_html.replace("\n", "<br/>"))

In [19]:
# FIXME: `basterds_result` is never assigned anywhere in this notebook (the cells
# above only build `shawshank_result`), so this cell raises NameError on a fresh
# kernel. Presumably it came from a deleted cell that called
# `evaluate_coreference_by_joining_elements` on the Basterds annotation files —
# TODO restore that cell so Restart & Run All works.
gold_entity_to_cluster = basterds_result["gold_clusters"]
sys_clusters = basterds_result["sys_clusters"]
mention_tags = basterds_result["mention_tags"]

In [20]:
len(mention_tags)

7828

In [21]:
basterds_result.keys()

dict_keys(['evaluation', 'gold_clusters', 'sys_clusters', 'coref_dataframe', 'document', 'coref_result', 'mention_tags'])

In [22]:
spacy_nlp = spacy.load("en_core_web_sm")

In [23]:
spacy_document = spacy_nlp(basterds_result["document"])

In [24]:
len(spacy_document)

7828

In [25]:
document = basterds_result["document"]

## Speaker names

In [26]:
# Extract speaker cues (presumably character names): every maximal run of
# tokens whose mention tag is "C" becomes one speaker string, and speakers_idx
# records the inclusive (first_token, last_token) indices of that run.
assert len(mention_tags) == len(spacy_document), "mention tags and spaCy tokens are misaligned"

speakers = []
speakers_idx = []

n_tokens = len(spacy_document)
i = 0
while i < n_tokens:
    if mention_tags[i] != "C":
        i += 1
        continue
    j = i
    # Bounds check: without it, a "C" run that reaches the last token indexes
    # past the end of mention_tags (IndexError).
    while j < n_tokens and mention_tags[j] == "C":
        j += 1
    # Slice the raw document by character offsets so spacing inside
    # multi-token names (e.g. "LT. HI COX") is preserved.
    begin = spacy_document[i].idx
    end = spacy_document[j - 1].idx + len(spacy_document[j - 1])
    speakers.append(document[begin:end])
    speakers_idx.append((i, j - 1))
    i = j

In [27]:
len(speakers)

257

In [28]:
len(set(speakers))

63

In [29]:
set(speakers)

{'A SUBTITLE APPEARS:',
 'ALDO',
 "ALDO'S VOICE",
 'Am I German?',
 'BACK TO BASTERDS',
 'BRIDGET',
 "BRIDGET'S VOICE",
 'BRIDGET)',
 'BRIDGET/GENGUS',
 'DAGGER.',
 'EDGAR WALLACE',
 'EDGER WALLACE I want to try.',
 'ERIC',
 'FEMALE SGT/BEETHOVEN',
 'FINGERS)',
 "FIVE NAZI'S",
 'GERMAN SGT',
 'GERMAN VOICE Then might I inquire?',
 'HELLSTROM',
 'HELLSTROMVON HAMMERSMARKWICKIHICOXSTIGLITZ',
 'HICOX)',
 'HIRSCHBERG',
 'HIRSCHBERG)',
 'I 09. WICKI and HATA HARI',
 'I guess you do.',
 'LOW in GERMAN)',
 'LOW)',
 "LT I'! I COX",
 'LT. H ICOX',
 'LT. HI COX',
 'LT. HICOX',
 'LT.ALDO',
 'LT.HICOX',
 'LT.NICOX',
 'MAJ.KING KONG',
 'MAJ.KING KONG Am I a American?',
 'MAJOR BELLSTROM',
 'MAJOR HELLSTROM',
 'MAJOR)',
 'MATA HARI',
 'NADINE, FRANCE"',
 'NAPOLEON.',
 'PFC.HIRSCHBERG',
 'R',
 'SGT.DONOWITZ',
 'SGT.POLA NEGRI',
 'SGT.STIGLITZ',
 'STIGLITZ',
 "STIGLITZ FIRES into HELLSTROM'S BALLS...",
 "STIGLITZ I'm making YOU,...",
 'STIGLITZ)',
 'THEN...',
 'The TABLE',
 'UNDERTABLE',
 'WHEN...',
 

In [30]:
mention_coref_tags = np.full(len(mention_tags), -1, dtype=int)

In [31]:
# Project system clusters onto tokens: every token inside a span of cluster c
# gets tag c (on overlap, the later cluster in sys_clusters wins).
for cluster_id, spans in enumerate(sys_clusters):
    for start, end in spans:
        mention_coref_tags[start:end + 1] = cluster_id

In [37]:
# One line per speaker occurrence: the name, then the set of system-cluster ids
# covering its tokens (-1 means the token is in no system cluster).
for speaker, (start, end) in zip(speakers, speakers_idx):
    coref_tags = set(mention_coref_tags[start: end + 1])
    print(f"{speaker:30s} {coref_tags}")

A SUBTITLE APPEARS:            {-1}
NADINE, FRANCE"                {26, -1}
LT.ALDO                        {-1}
LT.HICOX                       {1}
LT.ALDO                        {-1}
LT.HICOX                       {1}
LT.ALDO                        {-1}
WICKI                          {2}
LT.HICOX                       {1}
LT.NICOX                       {4}
STIGLITZ                       {3}
LT.HICOX                       {1}
LT.HICOX                       {1}
STIGLITZ                       {3}
LT.HICOX                       {1}
I guess you do.                {1, 3, -1}
LT. HI COX                     {1, -1}
LT.ALDO                        {-1}
LT.HICOX                       {1}
LT.ALDO                        {0}
LT.HICOX                       {1}
SGT.DONOWITZ                   {-1}
LT.HICOX                       {1}
SGT.DONOWITZ                   {-1}
LT.HICOX                       {1}
LT.ALDO                        {-1}
LT.HICOX                       {1}
PFC.HIRSCHBERG                 

In [39]:
textdistance.lcsseq("heAwlljo","hrelqeloo")

'hello'

In [42]:
speaker_similarity = np.zeros((len(speakers), len(speakers)))

In [44]:
# Pairwise name similarity |LCS(a, b)| / min(|a|, |b|) on whitespace-collapsed,
# lower-cased names; only the strict upper triangle (i < j) is filled.
# Names are normalised once up front instead of inside the O(n^2) loop, and the
# regex is a raw string (a plain "\s+" is an invalid-escape warning).
normalized_speaker_names = [re.sub(r"\s+", " ", name).strip().lower() for name in speakers]
for i in range(len(normalized_speaker_names)):
    speaker_i = normalized_speaker_names[i]
    for j in range(i + 1, len(normalized_speaker_names)):
        speaker_j = normalized_speaker_names[j]
        d = min(len(speaker_i), len(speaker_j))
        if d == 0:
            # A name that is pure whitespace would otherwise raise
            # ZeroDivisionError; leave its similarity at 0.
            continue
        n = len(textdistance.lcsseq(speaker_i, speaker_j))
        speaker_similarity[i, j] = n / d

In [46]:
speaker_similarity

array([[0.        , 0.4       , 0.28571429, ..., 0.5       , 0.4       ,
        0.5       ],
       [0.        , 0.        , 0.28571429, ..., 0.5       , 0.4       ,
        0.5       ],
       [0.        , 0.        , 0.        , ..., 1.        , 0.        ,
        1.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        1.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [70]:
def cluster(speakers, min_sim, max_merges):
    """Greedily group speaker-name strings whose spellings are similar.

    Names are normalised by deleting all whitespace and lower-casing, so
    "LT. HI COX" and "LT.HICOX" start out in the same group. Distinct
    normalised names are visited from most to least frequent. Each one absorbs
    later (less frequent) names whose similarity
    ``2 * |LCS(a, b)| / (|a| + |b|)`` is at least ``min_sim``. Candidates are
    absorbed starting from the least frequent, and at most ``max_merges`` of
    them are absorbed per group. Absorbed names are not visited again.

    Args:
        speakers: speaker strings, one entry per occurrence.
        min_sim: similarity threshold for merging two normalised names.
        max_merges: maximum number of other names a single group may absorb.

    Returns:
        (id_clusters, speaker_clusters). ``id_clusters`` is a list of sets of
        indices into ``speakers``, largest first. ``speaker_clusters`` lists the
        original strings (in index order) of every cluster with 2+ occurrences.
    """
    normalized_speakers = [re.sub(r"\s+", "", speaker).lower() for speaker in speakers]

    # dict.fromkeys de-duplicates in first-occurrence order. list(set(...))
    # would order strings by their per-process (randomised) hash, making ties
    # in the frequency sort, and therefore the greedy merges, change from one
    # kernel restart to the next.
    unique_speakers = list(dict.fromkeys(normalized_speakers))
    position = {name: k for k, name in enumerate(unique_speakers)}

    # Each group is [merge_count, normalised_name, set_of_occurrence_indices].
    groups = [[0, name, set()] for name in unique_speakers]
    for idx, name in enumerate(normalized_speakers):
        groups[position[name]][2].add(idx)
    groups = sorted(groups, key=lambda group: len(group[2]), reverse=True)

    i = 0
    while i < len(groups):
        anchor = groups[i]
        anchor_name = anchor[1]

        candidates = []
        for j in range(i + 1, len(groups)):
            other_name = groups[j][1]
            n = len(textdistance.lcsseq(anchor_name, other_name))
            sim = 2 * n / (len(anchor_name) + len(other_name))
            if sim >= min_sim:
                candidates.append(j)

        merged = set()
        for j in reversed(candidates):
            # Check the cap before merging, so max_merges <= 0 means no merges.
            if anchor[0] >= max_merges:
                break
            anchor[2].update(groups[j][2])
            anchor[0] += 1
            merged.add(j)

        groups = groups[:i + 1] + [groups[j] for j in range(i + 1, len(groups)) if j not in merged]
        i += 1

    id_clusters = sorted((group[2] for group in groups), key=len, reverse=True)
    speaker_clusters = [[speakers[k] for k in sorted(ids)] for ids in id_clusters if len(ids) >= 2]
    return id_clusters, speaker_clusters

In [48]:
speakers

['A SUBTITLE APPEARS:',
 'NADINE, FRANCE"',
 'LT.ALDO',
 'LT.HICOX',
 'LT.ALDO',
 'LT.HICOX',
 'LT.ALDO',
 'WICKI',
 'LT.HICOX',
 'LT.NICOX',
 'STIGLITZ',
 'LT.HICOX',
 'LT.HICOX',
 'STIGLITZ',
 'LT.HICOX',
 'I guess you do.',
 'LT. HI COX',
 'LT.ALDO',
 'LT.HICOX',
 'LT.ALDO',
 'LT.HICOX',
 'SGT.DONOWITZ',
 'LT.HICOX',
 'SGT.DONOWITZ',
 'LT.HICOX',
 'LT.ALDO',
 'LT.HICOX',
 'PFC.HIRSCHBERG',
 'LT.HICOX',
 'SGT.DONOWITZ',
 'LT.HICOX',
 'STIGLITZ',
 "FIVE NAZI'S",
 'WINNETOU',
 'FEMALE SGT/BEETHOVEN',
 'EDGAR WALLACE',
 'SGT.POLA NEGRI',
 'EDGAR WALLACE',
 'WINNETOU',
 'BRIDGET/GENGUS',
 'WINNETOU',
 'WINNETOU',
 'The TABLE',
 'WINNETOU',
 'The TABLE',
 'BRIDGET',
 'LT.HICOX',
 'BRIDGET',
 'WINNETOU',
 'MATA HARI',
 'BRIDGET',
 'MATA HARI',
 'BRIDGET',
 'SGT.POLA NEGRI',
 'MATA HARI',
 'EDGER WALLACE I want to try.',
 'THEN...',
 'BRIDGET',
 'LT.HICOX',
 'BRIDGET',
 'WICXI',
 'BRIDGET',
 'LT.HICOX',
 'NAPOLEON.',
 'BRIDGET',
 'LT.HICOX',
 'BRIDGET',
 'BRIDGET',
 'BACK TO BASTERDS',
 'BR

In [71]:
id_clusters, speaker_clusters = cluster(speakers, 0.6, 3)

In [72]:
speaker_clusters

[['LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.NICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT. HI COX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT. H ICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX',
  'HICOX)',
  'LT.HICOX',
  'LT.HICOX',
  'LT. HICOX',
  'LT.HICOX',
  'LT.HICOX',
  "LT I'! I COX",
  'LT.HICOX',
  'LT.HICOX',
  'LT.HICOX'],
 ['MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM',
  'HELLSTROM',
  'MAJOR HELLSTROM',
  'MAJOR HELLSTROM

In [76]:
len(speakers), len(speakers_idx)

(257, 257)

In [77]:
speakers_idx

[(33, 36),
 (43, 46),
 (95, 95),
 (113, 113),
 (121, 121),
 (132, 132),
 (140, 140),
 (187, 187),
 (204, 204),
 (341, 341),
 (348, 348),
 (374, 374),
 (394, 394),
 (468, 468),
 (479, 479),
 (491, 495),
 (518, 521),
 (546, 546),
 (561, 561),
 (569, 569),
 (597, 597),
 (652, 652),
 (674, 674),
 (682, 682),
 (692, 692),
 (710, 710),
 (733, 733),
 (757, 757),
 (815, 815),
 (826, 826),
 (839, 839),
 (890, 890),
 (1056, 1058),
 (1307, 1307),
 (1328, 1331),
 (1337, 1338),
 (1346, 1347),
 (1363, 1364),
 (1413, 1413),
 (1439, 1441),
 (1492, 1492),
 (1530, 1530),
 (1540, 1541),
 (1546, 1546),
 (1555, 1556),
 (1637, 1637),
 (1665, 1665),
 (1682, 1682),
 (1693, 1693),
 (1747, 1748),
 (1775, 1775),
 (1791, 1792),
 (1818, 1818),
 (1835, 1836),
 (1876, 1877),
 (1902, 1908),
 (1945, 1946),
 (1981, 1981),
 (2105, 2105),
 (2121, 2121),
 (2156, 2156),
 (2163, 2163),
 (2202, 2202),
 (2282, 2283),
 (2297, 2297),
 (2313, 2313),
 (2318, 2318),
 (2453, 2453),
 (2537, 2539),
 (2574, 2574),
 (2614, 2615),
 (263

In [78]:
heuristic_coref_tags = np.full(len(mention_coref_tags), -1, dtype=int)

In [79]:
# Paint each speaker-name token span with the id of the speaker cluster it
# belongs to. The loop variable is deliberately NOT named `cluster`: that would
# rebind (and destroy) the module-level `cluster()` function, so re-running the
# clustering cell would then fail.
for ci, speaker_ids in enumerate(id_clusters):
    for speaker_i in speaker_ids:
        start, end = speakers_idx[speaker_i]
        heuristic_coref_tags[start:end + 1] = ci

In [81]:
len(heuristic_coref_tags)

7828

In [82]:
# Token-level view of the document: spaCy token, its mention tag, and the
# heuristic speaker-cluster id (-1 = none). A DataFrame gives a truncated rich
# display instead of thousands of printed lines. Unlike zip, it raises if the
# three columns have different lengths instead of silently dropping rows.
token_tags = pd.DataFrame({
    "token": [token.text for token in spacy_document],
    "mention_tag": list(mention_tags),
    "heuristic_cluster": heuristic_coref_tags,
})
token_tags

EXT                            S -1
-                              S -1
LA                             S -1
LOUISIANE                      S -1
(                              S -1
TAVERN                         S -1
)                              S -1
-                              S -1
NIGHT                          S -1

                              X -1
We                             N -1
see                            N -1
a                              N -1
small                          N -1
basement                       N -1
tavern                         N -1
,                              N -1
with                           N -1
a                              N -1
old                            N -1
rustic                         N -1
sign                           N -1
out                            N -1
front                          N -1
that                           N -1
reads                          N -1
,                              N -1
"                           

In [83]:
sys_clusters

[{(52, 52),
  (54, 54),
  (510, 510),
  (514, 514),
  (526, 526),
  (539, 539),
  (566, 566),
  (569, 569),
  (618, 618),
  (6317, 6318),
  (6326, 6327),
  (6352, 6353),
  (6445, 6446),
  (6474, 6475),
  (6599, 6600),
  (6634, 6635),
  (6653, 6653),
  (6711, 6711),
  (6768, 6769),
  (6789, 6790),
  (6812, 6812),
  (6825, 6825),
  (6910, 6910),
  (6915, 6915),
  (6919, 6919),
  (6990, 6990),
  (6993, 6994),
  (6997, 6997),
  (7011, 7011),
  (7017, 7018),
  (7038, 7039),
  (7150, 7150),
  (7157, 7157),
  (7165, 7165),
  (7186, 7190),
  (7188, 7188),
  (7208, 7208),
  (7232, 7232),
  (7262, 7262),
  (7277, 7277),
  (7339, 7339),
  (7343, 7343),
  (7349, 7349),
  (7362, 7362),
  (7367, 7367),
  (7385, 7385),
  (7400, 7400),
  (7409, 7409),
  (7421, 7421),
  (7423, 7423),
  (7425, 7425),
  (7428, 7428),
  (7432, 7432),
  (7446, 7446),
  (7455, 7455),
  (7464, 7464),
  (7467, 7467),
  (7478, 7478),
  (7491, 7491),
  (7495, 7495),
  (7499, 7499),
  (7518, 7518),
  (7547, 7547),
  (7555, 7555)

In [84]:
# Rebuild span clusters from the per-token heuristic tags. Each maximal run of
# one non-negative tag becomes an inclusive (start, end) span in that tag's
# cluster. Note: two adjacent speaker spans with the same id would be joined
# into a single span.
heuristic_clusters_dict = defaultdict(set)
n_tags = len(heuristic_coref_tags)
i = 0
while i < n_tags:
    tag = heuristic_coref_tags[i]
    if tag == -1:
        i += 1
        continue
    j = i + 1
    # Bounds check: a tagged run ending at the last token would otherwise
    # index past the end of the array (IndexError).
    while j < n_tags and heuristic_coref_tags[j] == tag:
        j += 1
    heuristic_clusters_dict[tag].add((i, j - 1))
    i = j

heuristic_clusters = list(heuristic_clusters_dict.values())

In [85]:
heuristic_clusters

[{(33, 36)},
 {(43, 46)},
 {(95, 95),
  (121, 121),
  (140, 140),
  (546, 546),
  (569, 569),
  (710, 710),
  (7197, 7197),
  (7232, 7232),
  (7277, 7277),
  (7349, 7349),
  (7421, 7421),
  (7495, 7495),
  (7555, 7555),
  (7659, 7659),
  (7684, 7684),
  (7697, 7697)},
 {(113, 113),
  (132, 132),
  (204, 204),
  (341, 341),
  (374, 374),
  (394, 394),
  (479, 479),
  (518, 521),
  (561, 561),
  (597, 597),
  (674, 674),
  (692, 692),
  (733, 733),
  (815, 815),
  (839, 839),
  (1665, 1665),
  (2105, 2105),
  (2202, 2202),
  (2313, 2313),
  (3087, 3087),
  (3139, 3139),
  (3492, 3493),
  (3534, 3534),
  (3559, 3559),
  (3590, 3592),
  (3615, 3615),
  (3862, 3862),
  (3917, 3922),
  (3939, 3939),
  (3982, 3982),
  (4003, 4003),
  (4981, 4981),
  (5184, 5184),
  (5231, 5231),
  (5283, 5283),
  (5388, 5388),
  (5424, 5424),
  (5453, 5453),
  (5576, 5579),
  (5854, 5854),
  (5889, 5889),
  (5917, 5917),
  (5960, 5960)},
 {(187, 187),
  (2156, 2156),
  (3237, 3237),
  (3450, 3450),
  (4604, 4

In [86]:
len(sys_clusters), len(heuristic_clusters)

(40, 39)

In [91]:
# Rows: system clusters; columns: heuristic speaker clusters.
# Plain `int` replaces the `np.int` alias, which was deprecated in NumPy 1.20
# and removed in 1.24 (AttributeError).
intersection_mat = np.zeros((len(sys_clusters), len(heuristic_clusters)), dtype=int)

In [88]:
intersection_mat

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [92]:
# intersection_mat[r, c] = number of mention spans that system cluster r and
# heuristic speaker cluster c share (exact (start, end) matches).
for row, sys_cluster in enumerate(sys_clusters):
    intersection_mat[row, :] = [len(sys_cluster & heuristic) for heuristic in heuristic_clusters]

In [94]:
intersection_mat.tolist()

[[0,
  0,
  10,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0],
 [0,
  0,
  0,
  18,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  17,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [0,
  0,
  0,
  0,
  6,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [0,
  0,
  0,
  0,
  0,
  3,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [0,
  0,
  0,
  

In [97]:
len((intersection_mat > 0).sum(axis = 0))

39

In [98]:
(intersection_mat > 0).sum(axis = 0)

array([0, 0, 1, 8, 1, 2, 0, 1, 2, 0, 2, 1, 1, 3, 4, 1, 1, 0, 0, 0, 0, 0,
       0, 2, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 2, 0, 0])

In [99]:
(intersection_mat > 0).sum(axis = 1)

array([2, 2, 1, 1, 1, 1, 1, 1, 1, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
       0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1])

In [104]:
len(intersection_mat.sum(axis=1))

40

In [108]:
# Print the surface text of every system cluster whose spans overlap heuristic
# speaker spans more than once in total (row sum of intersection_mat > 1).
# NOTE(review): the cells above count *distinct* overlapping heuristic clusters
# via (intersection_mat > 0).sum(axis=1); if that was the intent here, use
# that mask instead of the raw row sum.
row_totals = intersection_mat.sum(axis=1)
for i in np.flatnonzero(row_totals > 1):
    mentions = [
        document[spacy_document[j].idx: spacy_document[k].idx + len(spacy_document[k])]
        for j, k in sys_clusters[i]
    ]
    print(len(mentions), mentions)

83 ['ALDO', 'Aldo with his hands up', "ALDO'S", "ALDO'S", "ALDO'S", 'I', "ALDO'S", 'you', 'ALDO', 'LT.ALDO', 'his', 'LT.ALDO', 'my', 'I', 'WILLI', 'ALDO', 'ALDO', 'Willi', 'ALDO', 'You', "ALDO'S", 'your', 'me', 'Aldo', 'ALDO', 'my', 'Aldo', 'I', 'ALDO', 'He', 'ALDO', 'my', 'Aldo', 'you', 'Aldo', 'him', 'yours', 'my', 'Aldo', 'I', 'Willi', 'i', 'your', 'Aldo', 'I', 'Willi', 'ALDO', 'ALDO', 'Aldo', 'He', 'I', 'I', "ALDO'S", 'he', "ALDO'S", 'that man', 'he', 'I', 'Aldo\n', "ALDO'S", 'Aldo', 'you', 'Aldo', 'Lieutenant', 'you', "Aldo's", "ALDO'S", 'your', 'you', 'I', 'your', 'Willi', 'I', 'He', 'I', 'your', 'he', 'Willi', 'Aldo', 'He', "ALDO'S", 'you', 'Aldo']
127 ["Cap't", 'Major', 'I', 'the Major', 'You', 'I', 'LT.HICOX', 'LT.HICOX', 'My', 'Your', 'I', 'my', 'I', 'my', 'Your', 'I', 'MAJOR HELLSTROM', 'MAJOR HELLSTROM', 'MAJOR HELLSTROM', 'LT.HICOX', 'LT.HICOX', 'Lt.Hicox', 'my', 'MAJOR HELLSTROM', 'your', 'you', 'his', 'COX', 'Lieutenant', 'Major', 'I', 'my', 'Major', 'you', 'LT.HICOX', '

In [112]:
def merge_clusters(clusters):
    """Transitively merge clusters (sets) that share at least one element.

    Repeatedly finds the first overlapping pair (a, b) in list order, removes
    both, and appends their union at the end. Untouched clusters keep their
    relative order. The input list and its sets are not modified.

    Returns:
        The merged sequence of clusters. If nothing overlaps, this is the
        original ``clusters`` object itself.
    """
    def first_overlapping_pair(groups):
        # Overlap is symmetric, so scanning only b > a finds the same first
        # pair that a full a != b scan would.
        return next(
            (
                (a, b)
                for a in range(len(groups))
                for b in range(a + 1, len(groups))
                if not groups[a].isdisjoint(groups[b])
            ),
            None,
        )

    pair = first_overlapping_pair(clusters)
    while pair is not None:
        a, b = pair
        survivors = [group for k, group in enumerate(clusters) if k not in pair]
        clusters = survivors + [clusters[a] | clusters[b]]
        pair = first_overlapping_pair(clusters)
    return clusters

In [113]:
merged_clusters = merge_clusters(sys_clusters + heuristic_clusters)

In [114]:
len(merged_clusters)

43

In [116]:
evaluate_coreference(gold_entity_to_cluster.values(), sys_clusters)

MUC  : P = 0.7876 R = 0.6782 F1 = 0.7288
B3   : P = 0.5691 R = 0.2532 F1 = 0.3505
CEAFe: P = 0.2100 R = 0.3653 F1 = 0.2667
CoNLL 2012 score: 0.4487


{'muc': {'R': 0.6781609195402298,
  'P': 0.787621359223301,
  'F1': 0.7288040426726559},
 'bcubed': {'R': 0.25324783617337143,
  'P': 0.5691203940018071,
  'F1': 0.3505206135514073},
 'ceafe': {'R': 0.36529256819057865,
  'P': 0.21004322670958273,
  'F1': 0.2667215577264543},
 'conll2012': {'R': 0.4322337746347267,
  'P': 0.5222616599782303,
  'F1': 0.4486820713168392}}

In [117]:
evaluate_coreference(gold_entity_to_cluster.values(), heuristic_clusters)

MUC  : P = 0.8440 R = 0.1923 F1 = 0.3132
B3   : P = 0.7488 R = 0.0396 F1 = 0.0753
CEAFe: P = 0.1331 R = 0.2256 F1 = 0.1674
CoNLL 2012 score: 0.1853


{'muc': {'R': 0.1922675026123302,
  'P': 0.8440366972477065,
  'F1': 0.3131914893617021},
 'bcubed': {'R': 0.039630684397902335,
  'P': 0.7488367065826119,
  'F1': 0.07527745985090081},
 'ceafe': {'R': 0.22561345436190705,
  'P': 0.13305408846984262,
  'F1': 0.167390627429802},
 'conll2012': {'R': 0.15250388045737986,
  'P': 0.5753091641000537,
  'F1': 0.1852865255474683}}

In [118]:
evaluate_coreference(gold_entity_to_cluster.values(), merged_clusters)

MUC  : P = 0.7938 R = 0.7482 F1 = 0.7703
B3   : P = 0.4753 R = 0.4514 F1 = 0.4630
CEAFe: P = 0.2158 R = 0.4035 F1 = 0.2812
CoNLL 2012 score: 0.5049


{'muc': {'R': 0.7481713688610241,
  'P': 0.7937915742793792,
  'F1': 0.7703066164604627},
 'bcubed': {'R': 0.451403553304683,
  'P': 0.47529936442495946,
  'F1': 0.46304337211006336},
 'ceafe': {'R': 0.4035249846359797,
  'P': 0.21583894527040773,
  'F1': 0.2812446862614404},
 'conll2012': {'R': 0.5343666356005622,
  'P': 0.4949766279915821,
  'F1': 0.5048648916106555}}

In [123]:
len(spacy_nlp("you your yours yourself yourselves"))

5