In [1]:
from collections import defaultdict
import jsonlines
import numpy as np
import re
import torch
import os
import itertools
import tqdm

In [2]:
data_dir = os.path.join(os.getenv("DATA_DIR"), "mica_text_coref") # type: ignore
print(data_dir)

/proj/sbaruah/data/mica_text_coref


# Spans from word clusters

In [2]:
pt = torch.load(os.path.join(data_dir, "movie_coref/results/coreference/baselines/preprocess_none.genre_bc.split_5120.overlap_512.train_wl.pt"))

In [16]:
with jsonlines.open(os.path.join(data_dir, "movie_coref/results/coreference/baselines/preprocess_none.genre_bc.split_5120.overlap_512.train_wl.jsonlines"), mode = "r") as reader:
    docs = {doc["document_id"]: doc for doc in reader} # type: ignore

In [12]:
print(docs["bc_avengers_endgame_1"].keys())
print(pt["bc_avengers_endgame_1"].keys())

dict_keys(['movie', 'rater', 'token', 'pos', 'ner', 'parse', 'speaker', 'document_id', 'cased_words', 'offset', 'clusters', 'sent_offset', 'sent_id', 'word_clusters', 'span_clusters'])
dict_keys(['coref_scores', 'coref_y', 'top_indices', 'word_clusters', 'span_clusters', 'span_scores', 'span_y'])


In [5]:
for doc_id, pt_doc in pt.items():
    word_clusters = pt_doc["word_clusters"]
    span_clusters = pt_doc["span_clusters"]
    heads = [word for cluster in word_clusters for word in cluster]
    spans = [span for cluster in span_clusters for span in cluster]
    print(f"doc_id = {doc_id}, {len(heads)} heads, {len(spans)} spans")
    for head, span in zip(heads, spans):
        if head < span[0] or head >= span[1]:
            print(head, span)

doc_id = bc_avengers_endgame_1, 922 heads, 922 spans
doc_id = bc_avengers_endgame_2, 957 heads, 957 spans
doc_id = bc_avengers_endgame_3, 929 heads, 929 spans
doc_id = bc_avengers_endgame_4, 1006 heads, 1006 spans
doc_id = bc_avengers_endgame_5, 1001 heads, 1001 spans
doc_id = bc_avengers_endgame_6, 984 heads, 984 spans
doc_id = bc_avengers_endgame_7, 956 heads, 956 spans
doc_id = bc_avengers_endgame_8, 683 heads, 683 spans
doc_id = bc_dead_poets_society_1, 1053 heads, 1053 spans
doc_id = bc_dead_poets_society_2, 990 heads, 990 spans
doc_id = bc_dead_poets_society_3, 1033 heads, 1033 spans
doc_id = bc_dead_poets_society_4, 1068 heads, 1068 spans
doc_id = bc_dead_poets_society_5, 1152 heads, 1152 spans
doc_id = bc_dead_poets_society_6, 725 heads, 725 spans
doc_id = bc_john_wick_1, 776 heads, 776 spans
doc_id = bc_john_wick_2, 786 heads, 786 spans
doc_id = bc_john_wick_3, 841 heads, 841 spans
doc_id = bc_john_wick_4, 813 heads, 813 spans
doc_id = bc_john_wick_5, 711 heads, 711 spans
doc_

# Merge coref scores

In [20]:
def combine_coref_scores(corefs: list[torch.Tensor], inds: list[torch.Tensor], overlap_lens: list[int], strategy: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Combine corefs and inds into a single coref and ind tensor.
    
    Args:
        corefs: list[tensor[*, k + 1]]
        inds: list[tensor[*, k]]
        overlap_lens: list[int]
        strategy: Can be one of "before", "after", "average", "max", "min", or "none"
        
    Return:
        coref: [n, 2k + 1]
        top_indices: [n, 2k]
    """
    # Assertions
    assert len(corefs) > 0, "Number of coref tensors should be atleast 1"
    assert len(corefs) == len(inds), "Number of coref tensors should equal number of indices tensors"
    if len(corefs) == 1: return corefs[0], inds[0]
    assert len(overlap_lens) == len(corefs) - 1, "Number of overlap lengths should equal one less than the number of coref tensors"

    # Intialize
    n = sum([len(coref) - overlap_len for coref, overlap_len in zip(corefs[:-1], overlap_lens)]) + len(corefs[-1])
    k = inds[0].shape[1]
    device = corefs[0].device
    coref = torch.full((n, 2*k), fill_value=-torch.inf, device=device)
    ind = torch.full((n, 2*k), fill_value=-1, device=device)
    coref_start, coref_end = 0, 0
    overlap_lens.extend([0, 0])

    # Combine
    for i in range(len(corefs)):
        assert len(corefs[i]) - overlap_lens[i - 1] - overlap_lens[i] > 0, "Atmost two segments should overlap"
        coref_start, coref_end = coref_end, coref_end + len(corefs[i]) - overlap_lens[i - 1] - overlap_lens[i]
        start, end = overlap_lens[i - 1], len(corefs[i]) - overlap_lens[i]

        # Non-overlapping
        coref[coref_start: coref_end, :k] = corefs[i][start: end, 1:]
        ind[coref_start: coref_end, :k] = inds[i][start: end]

        # Overlapping
        coref_start, coref_end = coref_end, coref_end + overlap_lens[i]
        if strategy != "none" and i < len(corefs) - 1:
            for j in range(overlap_lens[i]):
                heads_x, heads_y = inds[i][end + j].tolist(), inds[i + 1][j].tolist()
                scores_x, scores_y = corefs[i][end + j, 1:].tolist(), corefs[i + 1][j, 1:].tolist()
                head_to_score = {h: s for h, s in zip(heads_x, scores_x) if s != -torch.inf}
                for h, s in zip(heads_y, scores_y):
                    if s != -torch.inf:
                        if h in head_to_score:
                            if strategy == "after": head_to_score[h] = s
                            elif strategy == "mean": head_to_score[h] = 0.5 * (head_to_score[h] + s)
                            elif strategy == "max": head_to_score[h] = max(head_to_score[h], s)
                            elif strategy == "min": head_to_score[h] = min(head_to_score[h], s)
                        else: head_to_score[h] = s
                for l, (h, s) in enumerate(head_to_score.items()):
                    coref[coref_start + j, l] = s
                    ind[coref_start + j, l] = h
        else:
            coref[coref_start: coref_end, :k] = corefs[i][end:, 1:]
            ind[coref_start: coref_end, :k] = inds[i][end:]

    # Add dummy
    dummy = torch.zeros((n, 1), device=coref.device)
    coref = torch.cat((dummy, coref), dim=1)
    return coref, ind

# Get clusters from fused predictions

In [10]:
cache = {}
strategy_arr = ["before", "after", "average", "max", "min", "none"]
strategy = "max"

In [13]:
def get_scores_indices_heads(pt: dict, offset: tuple[int, int]) -> tuple[torch.Tensor, torch.Tensor, dict[int, tuple[int, int]]]:
    coref, ind, word_clusters, span_cluster = pt["coref_scores"], pt["top_indices"], pt["word_clusters"], pt["span_clusters"]
    ind = ind + offset[0]
    heads = [word + offset[0] for cluster in word_clusters for word in cluster]
    spans = [(p + offset[0], q + offset[0]) for cluster in span_clusters for p, q in cluster]
    head2span = {head: span for head, span in zip(heads, spans)}
    return coref, ind, head2span

In [23]:
class GraphNode:
    def __init__(self, node_id: int):
        self.id = node_id
        self.links: set[GraphNode] = set()
        self.visited = False

    def link(self, another: "GraphNode"):
        self.links.add(another)
        another.links.add(self)

    def __repr__(self) -> str:
        return str(self.id)

In [25]:
def clusterize(scores: torch.Tensor, top_indices: torch.Tensor):
    antecedents = scores.argmax(dim=1) - 1
    not_dummy = antecedents >= 0
    coref_span_heads = torch.arange(0, len(scores))[not_dummy]
    antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]

    nodes = [GraphNode(i) for i in range(len(scores))]
    for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
        nodes[i].link(nodes[j])
        assert nodes[i] is not nodes[j]

    clusters = []
    for node in nodes:
        if len(node.links) > 0 and not node.visited:
            cluster = []
            stack = [node]
            while stack:
                current_node = stack.pop()
                current_node.visited = True
                cluster.append(current_node.id)
                stack.extend(link for link in current_node.links if not link.visited)
            assert len(cluster) > 1
            clusters.append(sorted(cluster))
    return sorted(clusters)

In [28]:
preprocess_arr = ["none", "addsays", "nocharacters"]
genre_arr = ["bc", "bn", "mz", "nw", "pt", "tc", "wb"]
split_len_arr = [2048, 3072, 4096, 5120]
overlap_len_arr = [128, 256, 512]
n_settings = len(preprocess_arr) * len(genre_arr) * len(split_len_arr) * len(overlap_len_arr)
strategy = "average"

for preprocess, genre, split_len, overlap_len in itertools.product(preprocess_arr, genre_arr, split_len_arr, overlap_len_arr):
    setting = (preprocess, genre, split_len, overlap_len)
    print(f"preprocess={preprocess} genre={genre} split_len={split_len} overlap_len={overlap_len}\n")
    
    # Read docs
    with jsonlines.open(os.path.join(data_dir, f"movie_coref/results/coreference/baselines/preprocess_{preprocess}.genre_{genre}.split_{split_len}.overlap_{overlap_len}.train_wl.jsonlines")) as reader:
        docs = {doc["document_id"]:doc for doc in reader} # type: ignore
    pt = cache.get(setting, torch.load(os.path.join(data_dir, f"movie_coref/results/coreference/baselines/preprocess_{preprocess}.genre_{genre}.split_{split_len}.overlap_{overlap_len}.train_wl.pt")))
    cache[setting] = pt
    
    # Movie to number of parts
    movie_to_n_parts = defaultdict(int)
    for doc_id in docs.keys():
        match = re.match(r"[a-z]{2}_(\w+)_(\d+)", doc_id)
        assert match is not None, "Improperly formatted document id"
        movie = match.group(1)
        part = int(match.group(2))
        movie_to_n_parts[movie] = max(part, movie_to_n_parts[movie])
    
    # Loop over movie and parts
    for movie, n_parts in movie_to_n_parts.items():
        print(movie)
        corefs, inds, offsets, head2span = [], [], [], {}
        for i in range(1, n_parts + 1):
            offset = docs[f"{genre}_{movie}_{i}"]["offset"]
            coref, ind, _head2span = get_scores_indices_heads(pt[f"{genre}_{movie}_{i}"], offset)
            corefs.append(coref)
            inds.append(ind)
            offsets.append(offset)
            head2span.update(_head2span)
        overlap_lens = [offsets[i][1] - offsets[i + 1][0] for i in range(n_parts - 1)]
        coref_lens = [len(coref) for coref in corefs]
        print(f"{n_parts} parts, sub-document lens = {coref_lens}, overlap lens = {overlap_lens}")
        coref, ind = combine_coref_scores(corefs, inds, overlap_lens, strategy)
        print(f"merged document len = {len(coref)}, coref shape = {coref.shape}, ind shape = {ind.shape}")

        word_clusters = clusterize(coref, ind)
        span_clusters = []
        for cluster in word_clusters:
            span_cluster = []
            for head in cluster:
                if head in head2span:
                    span_cluster.append(head2span[head])
            if span_cluster:
                span_clusters.append(span_cluster)
        n_word_mentions = sum([len(cluster) for cluster in word_clusters])
        n_span_mentions = sum([len(cluster) for cluster in span_clusters])
        print(f"{n_word_mentions} word mentions, {n_span_mentions} span mentions")

        print()
    break

preprocess=none genre=bc split_len=2048 overlap_len=128

avengers_endgame
19 parts, sub-document lens = [2014, 2035, 2044, 2046, 2045, 2047, 2028, 2041, 2027, 2042, 2027, 2016, 2043, 2039, 2042, 2031, 1932, 2025, 1528], overlap lens = [108, 127, 171, 133, 125, 131, 114, 146, 112, 128, 150, 105, 135, 125, 122, 114, 80, 110]
merged document len = 35816, coref shape = torch.Size([35816, 101]), ind shape = torch.Size([35816, 100])
6737 word mentions, 6737 span mentions

dead_poets_society
14 parts, sub-document lens = [2029, 2041, 2041, 2022, 2025, 1978, 2043, 1992, 2024, 2015, 2045, 2047, 2047, 1487], overlap lens = [110, 125, 195, 171, 113, 64, 127, 83, 105, 119, 137, 158, 129]
merged document len = 26200, coref shape = torch.Size([26200, 101]), ind shape = torch.Size([26200, 100])
5435 word mentions, 5435 span mentions

john_wick
14 parts, sub-document lens = [2046, 1998, 2021, 2048, 2038, 2039, 2046, 2045, 2024, 2048, 2031, 1956, 2046, 239], overlap lens = [126, 143, 154, 132, 128, 120

In [3]:
text = "bc_avengers_endgame_19"
match = re.match(r"\w+_(\d+)", text)
match.group(1)

'19'