In [1]:
from mica_text_coref.coref.movie_coref import baseline
from mica_text_coref.coref.movie_coref import data
from mica_text_coref.coref.movie_coref import rules
from mica_text_coref.coref.movie_coref import split_and_merge

import collections
import copy
import itertools
import jsonlines
import numpy as np
import os
import pandas as pd
import re
import torch
import tqdm
import typing

In [2]:
def lea(key: list[set[typing.Hashable]], response: list[set[typing.Hashable]]) -> tuple[float, float, float, float]:
    """Compute the unnormalized LEA (link-based entity-aware) recall and precision terms.

    Each cluster is weighted by its size (its "importance") and scored by the fraction of its coreference links
    that are also found in the clusters of the other partition. A cluster of n > 1 mentions has n * (n - 1) / 2
    links; a singleton cluster is treated as having one (self) link, which is only resolved by a singleton
    cluster on the other side containing the same mention.

    Args:
        key: Gold clusters, each a non-empty set of hashable mentions.
        response: Predicted clusters, each a non-empty set of hashable mentions.

    Returns:
        The tuple (recall_numer, recall_denom, precision_numer, precision_denom). Recall is
        recall_numer / recall_denom and precision is precision_numer / precision_denom; the terms are returned
        separately so that callers can sum them over documents before dividing.

    Raises:
        AssertionError: If any key or response cluster is empty.
    """
    key_sizes = np.array([len(cluster) for cluster in key])
    response_sizes = np.array([len(cluster) for cluster in response])
    assert np.all(key_sizes > 0), "Empty cluster in key"
    assert np.all(response_sizes > 0), "Empty cluster in response"

    # overlap[i, j] = number of mentions shared by key cluster i and response cluster j.
    n_key, n_response = len(key), len(response)
    overlap = np.array(
        [len(key_cluster.intersection(response_cluster)) for key_cluster in key for response_cluster in response],
        dtype=int,
    ).reshape(n_key, n_response)

    # A singleton key cluster paired with a singleton response cluster shares its self link whenever the two
    # clusters hold the same mention; every other pair shares c * (c - 1) / 2 links for c common mentions.
    singleton_pair = (key_sizes == 1)[:, np.newaxis] & (response_sizes == 1)[np.newaxis, :]
    shared_links = np.where(singleton_pair, overlap, overlap * (overlap - 1) / 2)

    # Total number of links per cluster, floored at 1 so that singletons count their self link.
    key_total_links = np.maximum(key_sizes * (key_sizes - 1) / 2, 1)
    response_total_links = np.maximum(response_sizes * (response_sizes - 1) / 2, 1)

    recall_numer = (key_sizes * shared_links.sum(axis=1) / key_total_links).sum()
    precision_numer = (response_sizes * shared_links.sum(axis=0) / response_total_links).sum()
    recall_denom = key_sizes.sum()
    precision_denom = response_sizes.sum()
    return recall_numer, recall_denom, precision_numer, precision_denom

In [3]:
# All inputs and outputs of this notebook live under the directory named by the DATA_DIR environment variable.
_data_root = os.getenv("DATA_DIR")
if _data_root is None:
    # os.getenv returns None for an unset variable, which os.path.join would reject with an opaque TypeError;
    # fail here with an actionable message instead.
    raise RuntimeError("Environment variable DATA_DIR is not set; it must point to the data root directory")
# Directory holding the baseline coreference predictions; the result tables are written here as well.
data_dir = os.path.join(_data_root, "mica_text_coref/movie_coref/results/coreference/baselines")
# Directory holding the input corpora that are loaded as gold documents.
input_dir = os.path.join(_data_root, "mica_text_coref/movie_coref/results")

In [4]:
# Evaluate the baseline predictions on the dev excerpts with LEA for every combination of preprocessing, genre,
# entity filter, speaker merging, gold-mention filtering and gold-singleton removal. One row is recorded per
# (setting, movie) pair, plus one "all" row per setting aggregated over the movies of that setting.
preprocess_arr = ["none", "nocharacters", "addsays"]
genre_arr = ["bc", "bn", "mz", "nw", "pt", "tc", "wb"]
entity_arr = ["speaker", "person", "all"]
merge_speakers_arr = [False, True]
provide_gold_mentions_arr = [False, True]
remove_gold_singletons_arr = [False, True]
header = ["preprocess", "genre", "entity", "merge_speakers", "provide_gold_mentions", "remove_gold_singletons",
           "movie", "precision", "recall", "f1"]
rows = []
settings = itertools.product(preprocess_arr, genre_arr, entity_arr, merge_speakers_arr, provide_gold_mentions_arr,
                             remove_gold_singletons_arr)
# itertools.product has no len(), so the total is computed by hand for the progress bar.
n_settings = (len(preprocess_arr) * len(genre_arr) * len(entity_arr) * len(merge_speakers_arr)
              * len(provide_gold_mentions_arr) * len(remove_gold_singletons_arr))
for preprocess, genre, entity, merge_speakers, provide_gold_mentions, remove_gold_singletons in tqdm.tqdm(
        settings, total=n_settings, unit="setting"):
    file_ = os.path.join(data_dir, f"preprocess_{preprocess}.genre_{genre}.dev_wl.jsonlines")
    # Deliberately NOT named `data`: that would rebind the imported `data` module to a list and break the later
    # `data.CorefCorpus(...)` call when the notebook is run top to bottom.
    with jsonlines.open(file_, "r") as reader:
        dev_docs = list(reader)
    # Running sums of (recall_numer, recall_denom, precision_numer, precision_denom) over the movies.
    scores = np.zeros(4)
    for doc in dev_docs:
        gold_clusters: list[set[tuple[int, int]]] = []
        pred_clusters: list[set[tuple[int, int]]] = []
        for cluster in doc["clusters"].values():
            gold_cluster = set((begin, end) for begin, end, head in cluster)
            gold_clusters.append(gold_cluster)
        for cluster in doc["span_clusters"]:
            # Predicted span ends are shifted by one to line up with the gold spans
            # (presumably exclusive vs. inclusive ends -- TODO confirm).
            pred_cluster = set((begin, end - 1) for begin, end in cluster)
            pred_clusters.append(pred_cluster)

        # Merge predicted clusters by speaker names
        if merge_speakers:
            pred_clusters = rules.merge_speakers(doc["token"], doc["parse"], pred_clusters)

        # Filter predicted clusters by entity type
        if entity == "speaker":
            pred_clusters = rules.keep_speakers(doc["parse"], pred_clusters)
        elif entity == "person":
            pred_clusters = rules.keep_persons(doc["ner"], pred_clusters)

        # Remove gold clusters containing single mention
        if remove_gold_singletons:
            gold_clusters = rules.remove_singleton_clusters(gold_clusters)

        # Filter predicted mentions by gold mentions
        if provide_gold_mentions:
            gold_mentions = set([mention for cluster in gold_clusters for mention in cluster])
            pred_clusters = rules.filter_mentions(gold_mentions, pred_clusters)

        # If preprocess == "addsays" or "none", remove spans from gold and pred clusters that overlap with a speaker
        if preprocess == "addsays" or preprocess == "none":
            parse_arr = np.array(doc["parse"])
            clusters_arr = []
            for clusters in [gold_clusters, pred_clusters]:
                clusters_ = []
                for cluster in clusters:
                    cluster_ = set()
                    for begin, end in cluster:
                        # Keep the span only if none of its tokens carries the "C" parse tag.
                        if np.all(parse_arr[begin: end + 1] != "C"):
                            cluster_.add((begin, end))
                    # Clusters emptied by the filter are dropped; lea() rejects empty clusters.
                    if cluster_:
                        clusters_.append(cluster_)
                clusters_arr.append(clusters_)
            gold_clusters, pred_clusters = clusters_arr

        # LEA (the tiny epsilon guards against division by zero when a denominator is 0)
        movie_scores = lea(gold_clusters, pred_clusters)
        recall = movie_scores[0]/(movie_scores[1] + 1e-23)
        precision = movie_scores[2]/(movie_scores[3] + 1e-23)
        f1 = 2 * recall * precision / (recall + precision + 1e-23)
        rows.append([preprocess, genre, entity, merge_speakers, provide_gold_mentions, remove_gold_singletons,
                        doc["movie"], precision, recall, f1])
        scores += movie_scores
    # Aggregate over movies by summing numerators and denominators before dividing.
    recall = scores[0]/(scores[1] + 1e-23)
    precision = scores[2]/(scores[3] + 1e-23)
    f1 = 2 * recall * precision / (recall + precision + 1e-23)
    rows.append([preprocess, genre, entity, merge_speakers, provide_gold_mentions, remove_gold_singletons,
                    "all", precision, recall, f1])
dev_df = pd.DataFrame(rows, columns=header)

100%|██████████| 504/504 [00:38<00:00, 13.14setting/s]


In [5]:
# Sanity check on the table size: one row per (setting, movie) pair plus one "all" row per setting.
dev_df.shape

(2016, 10)

In [6]:
# Rank the aggregate ("all") rows of the settings that use neither gold mentions nor gold-singleton removal,
# best F1 first.
(
    dev_df
    .loc[lambda df: ~(df["provide_gold_mentions"] | df["remove_gold_singletons"]) & df["movie"].eq("all")]
    .sort_values("f1", ascending=False)
)

Unnamed: 0,preprocess,genre,entity,merge_speakers,provide_gold_mentions,remove_gold_singletons,movie,precision,recall,f1
1939,addsays,wb,speaker,True,False,False,all,0.653154,0.656889,0.655016
1555,addsays,mz,speaker,True,False,False,all,0.634836,0.661677,0.647979
1539,addsays,mz,speaker,False,False,False,all,0.644425,0.647715,0.646066
1923,addsays,wb,speaker,False,False,False,all,0.658277,0.633670,0.645739
1843,addsays,tc,speaker,True,False,False,all,0.627405,0.663864,0.645120
...,...,...,...,...,...,...,...,...,...,...
691,nocharacters,bc,speaker,True,False,False,all,0.470970,0.205830,0.286465
1267,nocharacters,wb,speaker,True,False,False,all,0.511666,0.191656,0.278859
1251,nocharacters,wb,speaker,False,False,False,all,0.511666,0.191656,0.278859
867,nocharacters,mz,speaker,False,False,False,all,0.481163,0.193766,0.276275


In [7]:
# Persist the dev scores (per-movie and aggregate rows) as a tab-separated file inside data_dir.
dev_df.to_csv(
    path_or_buf=os.path.join(data_dir, "excerpts.baseline.tsv"),
    sep="\t",
    index=False,
)

In [23]:
# Evaluate the split-and-merge baseline on the train scripts with LEA. Each script was predicted in overlapping
# parts; the per-part coreference scores are combined with every merge strategy and then scored under every
# combination of entity filter, speaker merging, gold-mention filtering and gold-singleton removal. Each row holds
# the raw LEA numerators and denominators so that they can be aggregated later.
preprocess_arr = ["none", "nocharacters", "addsays"]
genre_arr = ["bc", "bn", "mz", "nw", "pt", "tc", "wb"]
entity_arr = ["speaker", "person", "all"]
merge_speakers_arr = [False, True]
provide_gold_mentions_arr = [False, True]
remove_gold_singletons_arr = [False, True]
split_len_arr = [2048, 3072, 4096, 5120]
overlap_len_arr = [128, 256, 512]
strategy_arr = ["none", "before", "after", "max", "min", "average"]
header = ["preprocess", "genre", "entity", "merge_speakers", "provide_gold_mentions", "remove_gold_singletons",
          "split_len", "overlap_len", "merge_strategy", "movie", "recall_numer", "recall_denom", "precision_numer",
          "precision_denom"]
rows = []
outer_settings = itertools.product(preprocess_arr, genre_arr, split_len_arr, overlap_len_arr)
# Materialized as a list because it is iterated once per (movie, strategy) pair. A bare itertools.product
# iterator is exhausted after its first pass, so every later pass would silently contribute no rows.
inner_settings = list(itertools.product(entity_arr, merge_speakers_arr, provide_gold_mentions_arr,
                                        remove_gold_singletons_arr))
n_outer_settings = len(preprocess_arr) * len(genre_arr) * len(split_len_arr) * len(overlap_len_arr)
n_inner_settings = (len(entity_arr) * len(merge_speakers_arr) * len(provide_gold_mentions_arr)
                    * len(remove_gold_singletons_arr))

for preprocess, genre, split_len, overlap_len in outer_settings:
    # The unpreprocessed corpus is stored in the "regular" sub-directory.
    subdir = "regular" if preprocess == "none" else preprocess
    input_file = os.path.join(input_dir, subdir, "train_wl.jsonlines")
    data_file = os.path.join(data_dir, f"preprocess_{preprocess}.genre_{genre}.split_{split_len}.overlap_{overlap_len}"
                                        ".train_wl")
    # Per-part predicted documents, keyed by document id.
    with jsonlines.open(data_file + ".jsonlines") as reader:
        pred_docs = {doc["document_id"]: doc for doc in reader}
    # Per-part model outputs keyed by document id (contents are consumed by baseline.get_scores_indices_heads).
    pt = torch.load(data_file + ".pt", map_location="cpu")
    corpus = data.CorefCorpus(input_file)
    gold_docs = {doc.movie: doc for doc in corpus}

    # Document ids look like "<genre>_<movie>_<part>"; record the highest part number seen for each movie.
    movie_to_n_parts = collections.defaultdict(int)
    for doc_id in pred_docs.keys():
        match = re.match(r"[a-z]{2}_(\w+)_(\d+)", doc_id)
        assert match is not None, "Improperly formatted document id"
        movie = match.group(1)
        part = int(match.group(2))
        movie_to_n_parts[movie] = max(part, movie_to_n_parts[movie])

    # Loop over movie and parts
    for movie, n_parts in movie_to_n_parts.items():

        corefs, inds, offsets, head2span = [], [], [], {}
        for i in range(1, n_parts + 1):
            offset = pred_docs[f"{genre}_{movie}_{i}"]["offset"]
            coref, ind, _head2span = baseline.get_scores_indices_heads(pt[f"{genre}_{movie}_{i}"], offset)
            corefs.append(coref)
            inds.append(ind)
            offsets.append(offset)
            head2span.update(_head2span)
        # Overlap between consecutive parts: end of part i minus start of part i + 1.
        overlap_lens = [offsets[i][1] - offsets[i + 1][0] for i in range(n_parts - 1)]

        # Gold annotations depend only on the movie, so they are prepared once rather than once per strategy.
        gold_doc = gold_docs[movie]
        gold_clusters_ = [set([(mention.begin, mention.end) for mention in mentions])
                            for mentions in gold_doc.clusters.values()]
        # Parse tags of the CURRENT movie, used below to drop spans that overlap with a speaker.
        parse_arr = np.array(gold_doc.parse)

        for strategy in tqdm.tqdm(strategy_arr, desc=f"{preprocess}/{genre}/{split_len}/{overlap_len}:"
                                                     f"{movie}({n_parts})"):
            coref, ind = split_and_merge.combine_coref_scores(corefs, inds, overlap_lens, strategy)
            word_clusters = baseline.clusterize(coref, ind)
            # Map clustered head words to spans; heads without a span are skipped and clusters left empty dropped.
            span_clusters = []
            for cluster in word_clusters:
                span_cluster = []
                for head in cluster:
                    if head in head2span:
                        span_cluster.append(head2span[head])
                if span_cluster:
                    span_clusters.append(span_cluster)

            # Predicted span ends are shifted by one to line up with the gold spans
            # (presumably exclusive vs. inclusive ends -- TODO confirm).
            pred_clusters_ = [set([(i, j - 1) for i, j in cluster]) for cluster in span_clusters]

            for entity, merge_speakers, provide_gold_mentions, remove_gold_singletons in inner_settings:
                # Work on copies so that one setting cannot affect the clusters seen by the next.
                gold_clusters = copy.deepcopy(gold_clusters_)
                pred_clusters = copy.deepcopy(pred_clusters_)

                # Merge predicted clusters by speaker names
                if merge_speakers:
                    pred_clusters = rules.merge_speakers(gold_doc.token, gold_doc.parse, pred_clusters)

                # Filter predicted clusters by entity type
                if entity == "speaker":
                    pred_clusters = rules.keep_speakers(gold_doc.parse, pred_clusters)
                elif entity == "person":
                    pred_clusters = rules.keep_persons(gold_doc.ner, pred_clusters)

                # Remove gold clusters containing single mention
                if remove_gold_singletons:
                    gold_clusters = rules.remove_singleton_clusters(gold_clusters)

                # Filter predicted mentions by gold mentions
                if provide_gold_mentions:
                    gold_mentions = set([mention for cluster in gold_clusters for mention in cluster])
                    pred_clusters = rules.filter_mentions(gold_mentions, pred_clusters)

                # If preprocess == "addsays" or "none", remove spans from gold and pred clusters that overlap with a speaker
                if preprocess == "addsays" or preprocess == "none":
                    clusters_arr = []
                    for clusters in [gold_clusters, pred_clusters]:
                        clusters_ = []
                        for cluster in clusters:
                            cluster_ = set()
                            for begin, end in cluster:
                                # Keep the span only if none of its tokens carries the "C" parse tag.
                                if np.all(parse_arr[begin: end + 1] != "C"):
                                    cluster_.add((begin, end))
                            # Clusters emptied by the filter are dropped; lea() rejects empty clusters.
                            if cluster_:
                                clusters_.append(cluster_)
                        clusters_arr.append(clusters_)
                    gold_clusters, pred_clusters = clusters_arr

                # LEA
                movie_scores = lea(gold_clusters, pred_clusters)
                rows.append([preprocess, genre, entity, merge_speakers, provide_gold_mentions, remove_gold_singletons,
                             split_len, overlap_len, strategy, movie] + list(movie_scores))
train_df = pd.DataFrame(rows, columns=header)

none/bc/2048/128:avengers_endgame(19): 100%|██████████| 6/6 [00:23<00:00,  3.94s/it]
none/bc/2048/128:dead_poets_society(14): 100%|██████████| 6/6 [00:12<00:00,  2.06s/it]
none/bc/2048/128:john_wick(14): 100%|██████████| 6/6 [00:11<00:00,  1.94s/it]
none/bc/2048/128:prestige(20): 100%|██████████| 6/6 [00:18<00:00,  3.00s/it]
none/bc/2048/128:quiet_place(17): 100%|██████████| 6/6 [00:28<00:00,  4.67s/it]
none/bc/2048/128:zootopia(15): 100%|██████████| 6/6 [00:12<00:00,  2.12s/it]
none/bc/2048/256:avengers_endgame(21): 100%|██████████| 6/6 [00:39<00:00,  6.64s/it]
none/bc/2048/256:dead_poets_society(15): 100%|██████████| 6/6 [00:22<00:00,  3.71s/it]
none/bc/2048/256:john_wick(15): 100%|██████████| 6/6 [00:24<00:00,  4.10s/it]
none/bc/2048/256:prestige(21): 100%|██████████| 6/6 [00:43<00:00,  7.22s/it]
none/bc/2048/256:quiet_place(18):  33%|███▎      | 2/6 [00:15<00:30,  7.56s/it]


KeyboardInterrupt: 