In [54]:
from stanza.server import CoreNLPClient
import re
from tqdm import tqdm
import numpy as np
import pandas as pd
import json

In [2]:
# CoreNLP annotators, run in this order by the server.
annotators = ','.join(['tokenize', 'ssplit', 'pos', 'lemma', 'ner', 'depparse', 'parse', 'coref'])

# One CoreNLP server shared by every cell below (documents are annotated via
# `client.annotate`). Coreference uses the neural algorithm.
client = CoreNLPClient(
    annotators=annotators,
    properties={'coref.algorithm': 'neural'},
    output_format='serialized',
    threads=8,
    memory='32G',
    # Very large so full-length scripts are not cut off (stanza documents this value in ms).
    timeout=360000000,
    # Per-request character limit; raised so whole scripts can be sent at once.
    max_char_length=100000,
    be_quiet=True,
    # Server stderr is redirected here; this handle is never closed explicitly in the notebook.
    stderr=open('error.log', 'w'),
)

2020-11-20 17:14:58 INFO: Writing properties to tmp file: corenlp_server-12f9d39515304dcf.props


In [40]:
def annotate_script(script_name, data_dir = "/proj/sbaruah/movie/coreference/annotated-data"):
    """Annotate ``<data_dir>/<script_name>.script.txt`` with the shared CoreNLP ``client``.

    Prints ``"<script_name> done"`` when finished and returns the serialized
    CoreNLP document. The script file is closed before annotation starts.
    """
    with open(f"{data_dir}/{script_name}.script.txt") as script_file:
        script_text = script_file.read()
    doc = client.annotate(script_text, output_format="serialized")
    print(f"{script_name} done")
    return doc


bourne_doc = annotate_script("bourne")
basterds_doc = annotate_script("basterds")
shawshank_doc = annotate_script("shawshank")

bourne done
basterds done
shawshank done


## Get gold and pred clusters

In [45]:
def get_gold_and_pred_clusters(script_name, doc = None, verbose = True):
    script = open(f"/proj/sbaruah/movie/coreference/annotated-data/{script_name}.script.txt").read()
    if doc is None:
        doc = client.annotate(script, output_format="serialized")
    annotations = pd.read_csv(f"/proj/sbaruah/movie/coreference/annotated-data/{script_name}.coref.csv", index_col = None)
    
    gold_clusters = []
    for _, df in annotations.groupby("entityLabel"):
        cluster = []
        for _, row in df.iterrows():
            cluster.append((row["begin"], row["end"]))
        cluster = sorted(cluster, key = lambda mention: mention[0])
        gold_clusters.append(cluster)
    gold_clusters = sorted(gold_clusters, key = lambda cluster: len(cluster), reverse = True)
    
    gold_clusters_text = [[script[i:j] for i, j in cluster] for cluster in gold_clusters]
    
    pred_clusters = []
    for chain in doc.corefChain:
        cluster = []
        for mention in chain.mention:
            sentence = doc.sentence[mention.sentenceIndex]
            cluster.append((sentence.token[mention.beginIndex].beginChar, sentence.token[mention.endIndex - 1].endChar))
        cluster = sorted(cluster, key = lambda mention: mention[0])
        pred_clusters.append(cluster)
    pred_clusters = sorted(pred_clusters, key = lambda cluster: len(cluster), reverse = True)
    
    pred_clusters_text = [[script[i:j] for i, j in cluster] for cluster in pred_clusters]
    
    if verbose:
        print(f"{len(gold_clusters)} gold clusters. Top 20 clusters =>")
        for gi, (gold_cluster, gold_cluster_text) in enumerate(zip(gold_clusters[:20], gold_clusters_text[:20])):
            sampled_mentions = gold_cluster_text[:20]
            print(f"{gi:3d}\t{len(gold_cluster):3d} mentions\t{sampled_mentions}")
        print()

        print(f"{len(pred_clusters)} pred clusters. Top 20 clusters =>")
        for pi, (pred_cluster, pred_cluster_text) in enumerate(zip(pred_clusters[:22], pred_clusters_text[:22])):
            sampled_mentions = pred_cluster_text[:20]
            print(f"{pi:3d}\t{len(pred_cluster):3d} mentions\t{sampled_mentions}")
        print()
        
        intersection_map = np.zeros((len(gold_clusters), len(pred_clusters)), dtype=np.int)
        gold_to_pred_map = []

        for i, gold_cluster in enumerate(gold_clusters):
            for j, pred_cluster in enumerate(pred_clusters):
                gold_cluster_set = set([(k, l) for k, l in gold_cluster])
                pred_cluster_set = set([(k, l) for k, l in pred_cluster])
                intersection_map[i, j] = len(gold_cluster_set.intersection(pred_cluster_set))    

        while intersection_map.size and np.max(intersection_map):
            r, c = np.unravel_index(np.argmax(intersection_map, axis = None), intersection_map.shape)
            gold_to_pred_map.append([r, c, np.max(intersection_map)])
            intersection_map[r] = 0
            intersection_map[:,c] = 0

        if verbose:
            for gi, pi, count in gold_to_pred_map[:10]:
                print(f"gold cluster {gi} ({len(gold_clusters[gi])} mentions) <-> pred cluster {pi} ({len(pred_clusters[pi])} mentions): {count} common mentions")
                print(f"gold ==> {gold_clusters_text[gi][:50]}")    
                print(f"pred ==> {pred_clusters_text[pi][:50]}\n\n")
            print()
        
    info = {"script": script, "name": script_name, "doc": doc, "gold": gold_clusters, "pred": pred_clusters}
    return info

In [48]:
bourne_info = get_gold_and_pred_clusters(script_name="bourne", doc=bourne_doc)  # reuse the doc annotated above

39 gold clusters. Top 20 clusters =>
  0	425 mentions	['Bourne', 'Bourne', "BOURNE'S", 'Jason\n                  Bourne', 'he', 'Bourne', 'Bourne', 'he', 'he', 'Bourne', 'Bourne', 'Bourne', 'his', 'he', 'BOURNE', 'I', 'me', 'Bourne', 'Bourne', 'you']
  1	139 mentions	['Vosen', 'his', 'Vosen', 'Vosen', "VOSEN'S", 'His', 'He', 'Vosen', 'Vosen', 'VOSEN', 'VOSEN', 'Vosen', 'VOSEN', 'Vosen', 'VOSEN', 'VOSEN', 'I', 'Vosen', 'his', 'he']
  2	103 mentions	["LANDY'S", 'Pam', 'You', 'Landy', 'LANDY', 'LANDY', 'you', 'LANDY', 'I', 'I', "LANDY'S", "LANDY'S", 'She', 'LANDY', 'Pamela Landy', 'you', "LANDY'S", 'LANDY', 'her', 'LANDY']
  3	 72 mentions	["HIRSCH'S", "Hirsch's", 'He', 'DR. HIRSCH, 70', 'He', 'His', 'DR. HIRSCH', 'you', "HIRSCH'S", 'HIRSCH', 'he', 'he', 'HIRSCH', 'I', 'HIRSCH', 'I', 'you', 'HIRSCH', 'I', 'Hirsch']
  4	 26 mentions	['WILLS', 'Wills', 'WILLS', 'Wills', 'WILLS', 'Wills', 'Wills', 'WILLS', 'Wills', 'Wills', 'WILLS', 'Wills', 'Wills', 'Sir', 'Wills', 'WILLS', 'Wills', 'WILLS'

In [49]:
basterds_info = get_gold_and_pred_clusters(script_name="basterds", doc=basterds_doc)  # reuse the doc annotated above

23 gold clusters. Top 20 clusters =>
  0	188 mentions	['she', 'she', 'she', "Frau Von Hammersmark's", 'Frau Von Hammersmark', 'She', 'she', 'She', 'She', 'She', 'your fraulein\n          Von Hammer', 'the fraulein of\n          the hour, UFA diva, BRIDGET VON HAMMERSMARK', 'Bridget Von Hammersmark', 'her', 'her', 'BRIDGET/GENGUS', 'Fraulein Von Hammersmark', 'she', 'BRIDGET', 'my']
  1	180 mentions	['ONE GERMAN MASTER SGT', 'MASTER SGT #1', 'SGT.POLA NEGRI', 'SGT.POLA NEGRI', 'the Master Sgt', 'The Sgt over\n          there', 'His', 'him', 'his', 'Master Sgt.Pola Negri', 'his', 'he', 'the German Master Sgt', 'his', 'the German Master Sgt', 'SGT.POLA NEGRI', 'I', 'my', 'Wilhelm', 'This handsome happy Sgt']
  2	171 mentions	['A GERMAN VOICE', 'GERMAN VOICE', 'I', 'the unknown German', 'MAJOR DEITER HELLSTROM', 'The Major', 'he', 'MAJOR HELLSTROM', 'I', 'I', 'Major', 'MAJOR HELLSTROM', 'T', 'I', 'I', 'The Gestapo Major', 'MAJOR HELLSTROM', 'you', 'MAJOR HELLSTROM', 'you']
  3	152 mentions

In [50]:
shawshank_info = get_gold_and_pred_clusters(script_name="shawshank", doc=shawshank_doc)  # reuse the doc annotated above

44 gold clusters. Top 20 clusters =>
  0	293 mentions	["ANDY DUFRESNE, mid-20's, wire rim glasses, three-piece suit", 'He', 'his', 'His', 'He', 'He', 'him', 'He', 'He', 'his', 'He', 'He', 'his', 'His', 'He', 'He', 'His', 'He', 'his', 'he']
  1	185 mentions	['RED', 'his', 'Red', 'your', 'you', 'You', 'you', 'RED', 'I', 'my', 'I', 'I', 'I', 'him', 'RED', 'his', 'He', 'RED', 'me', 'I']
  2	 56 mentions	['HEYWOOD', 'HEYWOOD', 'my', 'you', 'Heywood', 'you', 'HEYWOOD', 'I', 'HEYWOOD', 'my', 'You', 'you', 'HEYWOOD', 'I', 'me', 'Heywood', 'HEYWOOD', 'I', 'I', 'HEYWOOD']
  3	 49 mentions	['that chubby fat-ass', 'Fat-Ass', 'Fat-Ass', 'Faaaat-Ass', 'boy', 'you', 'you', 'you', "FAT-ASS'", 'Fat-Ass', 'you', 'you', 'your', 'yours', 'Fat-Ass', 'FAT-ASS', 'I', 'I', 'FAT-ASS', 'Fat-Ass']
  4	 40 mentions	['WOMAN', 'her', 'her', 'She', 'her', 'her', 'her', 'She', 'her', 'her', 'her', 'her', 'WOMAN', 'The woman', 'WOMAN', 'I', 'your\n\t\twife', 'she', 'She', 'she']
  5	 34 mentions	['BYRON HADLEY, captai

In [51]:
# Manual gold -> pred cluster alignment per script. Keys index info["gold"]
# (the five largest gold clusters, since that list is sorted largest-first);
# values index info["pred"].
bourne_map = dict([(0, 0), (1, 1), (2, 3), (3, 2), (4, 7)])
basterds_map = dict([(0, 1), (1, 10), (2, 2), (3, 8), (4, 0)])
shawshank_map = dict([(0, 0), (1, 1), (2, 5), (3, 4), (4, 8)])

bourne_info.update(map=bourne_map)
basterds_info.update(map=basterds_map)
shawshank_info.update(map=shawshank_map)

## Scorch input: write the aligned gold/pred clusters as JSON files for scoring

In [55]:
# Output directory for the scorch gold/pred JSON files.
SCORCH_DIR = "/proj/sbaruah/movie/coreference/corefstanford"

for info in [bourne_info, basterds_info, shawshank_info]:
    # Keep only the aligned clusters: gold cluster `gold_idx` is paired with
    # pred cluster `pred_idx` through info["map"], in the map's iteration order.
    main_gold_clusters = [info["gold"][gold_idx] for gold_idx in info["map"]]
    main_pred_clusters = [info["pred"][pred_idx] for pred_idx in info["map"].values()]

    # Give every distinct (begin, end) span a single integer id shared by the gold
    # and pred files, so the same mention has the same id on both sides.
    mention_to_id = {}
    for clusters in [main_gold_clusters, main_pred_clusters]:
        for cluster in clusters:
            for i, j in cluster:
                if (i, j) not in mention_to_id:
                    mention_to_id[(i, j)] = len(mention_to_id)

    info["gold_id"] = [[mention_to_id[(i, j)] for i, j in cluster] for cluster in main_gold_clusters]
    info["pred_id"] = [[mention_to_id[(i, j)] for i, j in cluster] for cluster in main_pred_clusters]

    # Cluster keys are ints here; json.dump writes them as strings.
    gold_json = {"type": "clusters", "clusters": dict(enumerate(info["gold_id"]))}
    pred_json = {"type": "clusters", "clusters": dict(enumerate(info["pred_id"]))}

    # Context managers guarantee the files are flushed and closed.
    with open(f"{SCORCH_DIR}/scorch.{info['name']}.gold.json", "w") as gold_file:
        json.dump(gold_json, gold_file)
    with open(f"{SCORCH_DIR}/scorch.{info['name']}.pred.json", "w") as pred_file:
        json.dump(pred_json, pred_file)