In [135]:
import jsonlines
import pandas as pd
import numpy as np
from collections import Counter
import pickle
import json

## Get gold and pred clusters

In [2]:
rec = None
with jsonlines.open(f"corefhoi/data/bourne.english.512.jsonlines") as reader:
    for obj in reader:
        rec = obj
print(rec.keys())

dict_keys(['doc_key', 'tokens', 'sentences', 'speakers', 'constituents', 'ner', 'clusters', 'sentence_map', 'subtoken_map', 'pronouns'])


In [4]:
pred = pickle.load(open(f"corefhoi/data/bourne.clusters.pkl", "rb"))

In [93]:
def get_gold_and_pred_clusters(script_name, indices = [0,1,2,3,4], verbose = False):
    gold_clusters_text = []
    gold_clusters = []
    with jsonlines.open(f"corefhoi/data/{script_name}.english.512.jsonlines") as reader:
        for obj in reader:
            tokens = obj["tokens"]
            gold_clusters = obj["clusters"]
            subtoken_map = obj["subtoken_map"]
    
    gold_clusters = sorted(gold_clusters, key = lambda x: len(x), reverse = True)
    for i in range(len(gold_clusters)):
        gold_clusters[i] = sorted(gold_clusters[i], key = lambda x: x[0])
        
    for cluster in gold_clusters:
        mentions = []
        for i, j in cluster:
            token_i = subtoken_map[i]
            token_j = subtoken_map[j]
            mention = " ".join(tokens[token_i: token_j + 1])
            mentions.append(mention)
        gold_clusters_text.append(mentions)
    
    pred_clusters = pickle.load(open(f"corefhoi/data/{script_name}.clusters.pkl", "rb"))
    pred_clusters = list(pred_clusters.values())[0]
    pred_clusters = [[[i, j] for i, j in cluster] for cluster in pred_clusters]
    pred_clusters = sorted(pred_clusters, key = lambda x: len(x), reverse = True)
    for i in range(len(pred_clusters)):
        pred_clusters[i] = sorted(pred_clusters[i], key = lambda x: x[0])
    
    pred_clusters_text = []
    for cluster in pred_clusters:
        mentions = []
        for i, j in cluster:
            token_i = subtoken_map[i]
            token_j = subtoken_map[j]
            mention = " ".join(tokens[token_i: token_j + 1])
            mentions.append(mention)
        pred_clusters_text.append(mentions)
    
    if verbose:
        print(f"{len(gold_clusters)} gold clusters. Top 20 clusters =>")
        for gi, (gold_cluster, gold_cluster_text) in enumerate(zip(gold_clusters[:20], gold_clusters_text[:20])):
            sampled_mentions = gold_cluster_text[:20]
            print(f"{gi:3d}\t{len(gold_cluster):3d} mentions\t{sampled_mentions}")
        print()

        print(f"{len(pred_clusters)} pred clusters. Top 20 clusters =>")
        for pi, (pred_cluster, pred_cluster_text) in enumerate(zip(pred_clusters[:22], pred_clusters_text[:22])):
            sampled_mentions = pred_cluster_text[:20]
            print(f"{pi:3d}\t{len(pred_cluster):3d} mentions\t{sampled_mentions}")
        print()
    
    intersection_map = np.zeros((len(gold_clusters), len(pred_clusters)), dtype=np.int)
    gold_to_pred_map = []

    for i, gold_cluster in enumerate(gold_clusters):
        for j, pred_cluster in enumerate(pred_clusters):
            gold_cluster_set = set([(k, l) for k, l in gold_cluster])
            pred_cluster_set = set([(k, l) for k, l in pred_cluster])
            intersection_map[i, j] = len(gold_cluster_set.intersection(pred_cluster_set))    

    while intersection_map.size and np.max(intersection_map):
        r, c = np.unravel_index(np.argmax(intersection_map, axis = None), intersection_map.shape)
        gold_to_pred_map.append([r, c, np.max(intersection_map)])
        intersection_map[r] = 0
        intersection_map[:,c] = 0

    if verbose:
        for gi, pi, count in gold_to_pred_map[:10]:
            print(f"gold cluster {gi} ({len(gold_clusters[gi])} mentions) <-> pred cluster {pi} ({len(pred_clusters[pi])} mentions): {count} common mentions")
            print(f"gold ==> {gold_clusters_text[gi][:50]}")    
            print(f"pred ==> {pred_clusters_text[pi][:50]}\n\n")
        print()
    
    gold_to_pred_map = [gold_to_pred_map[i] for i in indices]
    
    return gold_clusters, gold_clusters_text, pred_clusters, pred_clusters_text, gold_to_pred_map

In [94]:
script_name = "basterds"

gold_clusters, gold_clusters_text, pred_clusters, pred_clusters_text, gold_to_pred_map = get_gold_and_pred_clusters(script_name, verbose = True)

23 gold clusters. Top 20 clusters =>
  0	184 mentions	['she', 'she', 'she', "Frau Von Hammersmark 's", 'Frau Von Hammersmark', 'She', 'she', 'She', 'She', 'She', 'your fraulein Von Hammer', 'the fraulein of the hour , UFA diva , BRIDGET VON HAMMERSMARK', 'Bridget Von Hammersmark', 'her', 'her', 'BRIDGET / GENGUS', 'Fraulein Von Hammersmark', 'she', 'BRIDGET', 'my']
  1	163 mentions	['ONE GERMAN MASTER SGT', 'MASTER SGT # 1', 'the Master Sgt', 'The Sgt over there', 'His', 'him', 'his', 'his', 'he', 'the German Master Sgt', 'his', 'the German Master Sgt', 'I', 'my', 'Wilhelm', 'This handsome happy Sgt', 'The German Master Sgt', 'his', 'his', 'He']
  2	157 mentions	['A GERMAN VOICE', 'GERMAN VOICE', 'I', 'the unknown German', 'MAJOR DEITER HELLSTROM', 'The Major', 'he', 'MAJOR HELLSTROM', 'I', 'I', 'Major', 'MAJOR HELLSTROM', 'T', 'I', 'I', 'The Gestapo Major', 'MAJOR HELLSTROM', 'you', 'MAJOR HELLSTROM', 'you']
  3	 97 mentions	['Hicox', 'You', 'I', 'You', 'You', 'Lt. Hicox', 'Lt. Hicox'

In [95]:
shawshank_gold_to_pred = {0:0, 1:1, 2:9, 3:4, 4:3}
bourne_gold_to_pred = {0:0, 1:6, 2:1, 3:4, 4:9}
basterds_gold_to_pred = {0:2, 1:10, 2:3, 3:0, 4:21}

In [104]:
shawshank_info = {"map": shawshank_gold_to_pred, "name": "shawshank"}
gold_clusters, gold_clusters_text, pred_clusters, pred_clusters_text, gold_to_pred_map = get_gold_and_pred_clusters("shawshank", verbose = False)
shawshank_info["gold"] = gold_clusters
shawshank_info["pred"] = pred_clusters

bourne_info = {"map": bourne_gold_to_pred, "name": "bourne"}
gold_clusters, gold_clusters_text, pred_clusters, pred_clusters_text, gold_to_pred_map = get_gold_and_pred_clusters("bourne", verbose = False)
bourne_info["gold"] = gold_clusters
bourne_info["pred"] = pred_clusters

basterds_info = {"map": basterds_gold_to_pred, "name": "basterds"}
gold_clusters, gold_clusters_text, pred_clusters, pred_clusters_text, gold_to_pred_map = get_gold_and_pred_clusters("basterds", verbose = False)
basterds_info["gold"] = gold_clusters
basterds_info["pred"] = pred_clusters

infos = [shawshank_info, bourne_info, basterds_info]

## Evaluate

In [112]:
def MUC(info):
    gold = [set([(i, j) for i, j in cluster]) for cluster in info["gold"]]
    pred = [set([(i, j) for i, j in cluster]) for cluster in info["pred"]]
    
    mention_to_gi = {}
    mention_to_pi = {}
    
    for gi, cluster in enumerate(gold):
        for mention in cluster:
            mention_to_gi[mention] = gi
            
    for pi, cluster in enumerate(pred):
        for mention in cluster:
            mention_to_pi[mention] = pi
            
    gold_partitions = []
    pred_partitions = []
    
    for cluster in gold:
        pi_set = set()
        for mention in cluster:
            if mention in mention_to_pi:
                pi_set.add(mention_to_pi[mention])
            else:
                pi_set.add(-1)
        gold_partitions.append(len(pi_set))
        
    for cluster in pred:
        gi_set = set()
        for mention in cluster:
            if mention in mention_to_gi:
                gi_set.add(mention_to_gi[mention])
            else:
                gi_set.add(-1)
        pred_partitions.append(len(gi_set))
        
    precision_numerator, precision_denominator = 0, 0
    recall_numerator, recall_denominator = 0, 0
    
    for cluster, partition in zip(pred, pred_partitions):
        precision_numerator += len(cluster) - partition
        precision_denominator += len(cluster) - 1
    
    for cluster, partition in zip(pred, gold_partitions):
        recall_numerator += len(cluster) - partition
        recall_denominator += len(cluster) - 1
        
    precision = precision_numerator/precision_denominator
    recall = recall_numerator/recall_denominator
    f1 = 2 * precision * recall / (precision + recall)
    
    return precision, recall, f1

In [118]:
def B3(info):
    gold = [set([(i, j) for i, j in cluster]) for cluster in info["gold"]]
    pred = [set([(i, j) for i, j in cluster]) for cluster in info["pred"]]
    precision_numerator, recall_numerator = 0, 0

    for gold_cluster in gold:
        for pred_cluster in pred:
            common = len(gold_cluster.intersection(pred_cluster))
            precision_numerator += common*common/(len(pred_cluster))
            recall_numerator += common*common/(len(gold_cluster))
    
    precision_denominator = sum([len(gold_cluster) for gold_cluster in gold])    
    recall_denominator = sum([len(pred_cluster) for pred_cluster in pred])
    
    precision = precision_numerator/precision_denominator
    recall = recall_numerator/recall_denominator
    f1 = 2 * precision * recall / (precision + recall)
    
    return precision, recall, f1

In [130]:
def CEAF(info):
    gold = [set([(i, j) for i, j in cluster]) for cluster in info["gold"]]
    pred = [set([(i, j) for i, j in cluster]) for cluster in info["pred"]]
    numerator = 0
    
    for gi, pi in info["map"].items():
        numerator += 2 * len(gold[gi].intersection(pred[pi]))/ (len(gold[gi]) + len(pred[pi]))

    precision = numerator/len(pred)
    recall = numerator/len(gold)
    f1 = 2 * precision * recall / (precision + recall)
    
    return precision, recall, f1

In [131]:
for info in infos:
    print(info["name"])
    info["metric"] = {}
    
    for metric in [B3, CEAF, MUC]:
        p, r, f = metric(info)
        info["metric"][metric.__name__] = {"p":p, "r":r, "f":f}
        print(f"\t{metric.__name__:5s}: precision = {p:.3f}, recall = {r:.3f}, f1 = {f:.3f}")

shawshank
	B3   : precision = 0.619, recall = 0.335, f1 = 0.435
	CEAF : precision = 0.019, recall = 0.073, f1 = 0.030
	MUC  : precision = 0.955, recall = 0.915, f1 = 0.934
bourne
	B3   : precision = 0.644, recall = 0.288, f1 = 0.398
	CEAF : precision = 0.016, recall = 0.071, f1 = 0.026
	MUC  : precision = 0.961, recall = 0.921, f1 = 0.940
basterds
	B3   : precision = 0.465, recall = 0.141, f1 = 0.217
	CEAF : precision = 0.010, recall = 0.078, f1 = 0.018
	MUC  : precision = 0.930, recall = 0.857, f1 = 0.892


## Scorch input

In [137]:
for info in infos:
    mention_to_id = {}
    
    for clusters in [info["gold"], info["pred"]]:
        for cluster in clusters:
            for i, j in cluster:
                if (i, j) not in mention_to_id:
                    mention_to_id[(i, j)] = len(mention_to_id)
    
    info["gold_id"] = [[mention_to_id[(i, j)] for i, j in cluster] for cluster in info["gold"]]    
    info["pred_id"] = [[mention_to_id[(i, j)] for i, j in cluster] for cluster in info["pred"]]
    
    gold_json = {"type":"clusters", "clusters": dict([(i, mention_ids) for i, mention_ids in enumerate(info["gold_id"])])}    
    pred_json = {"type":"clusters", "clusters": dict([(i + 5, mention_ids) for i, mention_ids in enumerate(info["pred_id"])])}
    
    json.dump(gold_json, open(f"corefhoi/data/scorch.{info['name']}.gold.json", "w"))    
    json.dump(pred_json, open(f"corefhoi/data/scorch.{info['name']}.pred.json", "w"))

## Error Analysis

In [138]:
bourne_info.keys()

dict_keys(['map', 'name', 'gold', 'pred', 'metric', 'gold_id', 'pred_id'])

In [140]:
with jsonlines.open(f"corefhoi/data/bourne.english.512.jsonlines") as reader:
    for obj in reader:
        bourne_rec = obj

In [141]:
bourne_rec.keys()

dict_keys(['doc_key', 'tokens', 'sentences', 'speakers', 'constituents', 'ner', 'clusters', 'sentence_map', 'subtoken_map', 'pronouns'])

In [142]:
bourne_info["map"]

{0: 0, 1: 6, 2: 1, 3: 4, 4: 9}

In [145]:
gold_cluster = bourne_info["gold"][0]
pred_cluster = bourne_info["pred"][0]

In [146]:
gold_cluster = set([(i, j) for i, j in gold_cluster])
pred_cluster = set([(i, j) for i, j in pred_cluster])

In [148]:
print(f"|G| = {len(gold_cluster)}, |P| = {len(pred_cluster)}, |G ⋂ P| = {len(gold_cluster.intersection(pred_cluster))}, |G - P| = {len(gold_cluster.difference(pred_cluster))}, |P - G| = {len(pred_cluster.difference(gold_cluster))}")

|G| = 421, |P| = 344, |G ⋂ P| = 311, |G - P| = 110, |P - G| = 33


In [149]:
for k, (i, j) in enumerate(gold_cluster.difference(pred_cluster)):
    ti, tj = bourne_rec["subtoken_map"][i], bourne_rec["subtoken_map"][j] + 1
    left, mention, right = " ".join(bourne_rec["tokens"][ti - 20: ti]), " ".join(bourne_rec["tokens"][ti: tj]), " ".join(bourne_rec["tokens"][tj: tj + 20])
    print(f"{k:3d}. left    = {left}")    
    print(f"{k:3d}. mention = {mention}")    
    print(f"{k:3d}. right   = {right}\n\n")

  0. left    = I guess I owe you an apology . 207 INT . HUB -- BACK ROOM 207 Vosen rushes in --
  0. mention = BOURNE
  0. right   = Is that official ? V.O. VOSEN Are we triangulating ?! OVERLAPPING LANDY No . Off the record . You know


  1. left    = his front tires spinning wildly Bourne fights to straighten his car as another CRI sedan appears and tries to box
  1. mention = Bourne 's
  1. right   = car by sliding into his path . 283 OMITTED 283 284 EXT . MANHATTAN STREET -- DAY 284 Bourne just


  2. left    = 's a bad place to pick , it 's too exposed . VOSEN He would n't have chosen it if
  2. mention = he
  2. right   = did n't have a reason . 223 EXT . MANHATTAN 223 LANDY walks east through the crowded streets . 224


  3. left    = is there . Gun pointed at his head ... The two assassins look at each other ... then Bourne lowers
  3. mention = his
  3. right   = gun ... and disappears ... ON PAZ -- as the wheels start to turn ... 287B INT . HUB 287B


  4. left    = into his 