In [25]:
from mica_text_coref.coref.seq_coref import data_util
from mica_text_coref.coref.seq_coref import data

import numpy as np
from scorch import scores
import torch
import tqdm
from transformers import LongformerTokenizer
import subprocess
import re
import tempfile

In [3]:
train_dataset = data_util.load_tensors(
    ("/home/sbaruah_usc_edu/mica_text_coref/data/tensors/"
    "longformer_seq_tensors_512/train"), device="cpu")

In [18]:
class Metric:

    def __init__(self, recall, precision) -> None:
        self.recall = recall
        self.precision = precision
        self.f1 = 0 if (self.precision == 0 and self.recall == 0) else (
            2 * self.precision * self.recall / (self.precision + self.recall))
    
    def __repr__(self) -> str:
        return (f"P = {100*self.precision:.1f}, R = {100*self.recall:.1f}, "
                f"F1 = {100*self.f1:.1f}")

class CoreferenceMetric:

    def __init__(self, muc: Metric, b3: Metric, ceafe: Metric,
     ceafm: Metric, mention: Metric) -> None:
        self.muc = muc
        self.b3 = b3
        self.ceafe = ceafe
        self.ceafm = ceafm
        self.mention = mention
    
    def __repr__(self) -> str:
        average_f1 = (self.muc.f1 + self.b3.f1 + self.ceafe.f1)/3
        desc = (f"MUC: {self.muc}\nB3: {self.b3}\nCEAFe: {self.ceafe}\n"
                f"Average F1: {100*average_f1:.1f}\nMention: {self.mention}")
        return desc

def evaluate_clusters_scorch(
    groundtruth: dict[int, list[set[data.Mention]]],
    predictions: dict[int, list[set[data.Mention]]]) -> CoreferenceMetric:
    """Evaluates the predictions against the groundtruth annotations using the
    unofficial python scorch package.

    Args:
        groundtruth: A dictionary of list of groundtruth coreference clusters
            (set of data.Mention objects) keyed by the doc id.
        predictions: A dictionary of list of predicted coreference clusters
            (set of data.Mention objects) keyed by the doc id.
    
    Return:
        CoreferenceMetric. This contains scores for MUC, B3, CEAFe, CEAFm, and
        mention.
    """
    gold_clusters: list[set[tuple[int, data.Mention]]] = []
    pred_clusters: list[set[tuple[int, data.Mention]]] = []
    gold_mentions: set[tuple[int, data.Mention]] = set()
    pred_mentions: set[tuple[int, data.Mention]] = set()
    doc_keys: set[int] = set()

    for doc_key, clusters in tqdm.tqdm(groundtruth.items(),
        total=len(groundtruth),
        desc="Collecting groundtruth clusters and mentions"):
        doc_keys.add(doc_key)
        for cluster in clusters:
            gold_cluster: set[tuple[int, data.Mention]] = set()
            for mention in cluster:
                gold_cluster.add((doc_key, mention))
                gold_mentions.add((doc_key, mention))
            gold_clusters.append(gold_cluster)
    
    for doc_key, clusters in tqdm.tqdm(predictions.items(),
        total=len(predictions),
        desc="Collecting predictions clusters and mentions"):
        if doc_key in doc_keys:
            for cluster in clusters:
                pred_cluster: set[tuple[int, data.Mention]] = set()
                for mention in cluster:
                    pred_cluster.add((doc_key, mention))
                    pred_mentions.add((doc_key, mention))
                pred_clusters.append(pred_cluster)
    
    print("Calculating MUC")
    muc_recall, muc_precision, _ = scores.muc(gold_clusters, pred_clusters)
    print("Calculating B3")
    b3_recall, b3_precision, _ = scores.b_cubed(gold_clusters, pred_clusters)
    print("Calculating CEAF-e")
    ceafe_recall, ceafe_precision, _ = scores.ceaf_e(gold_clusters,
                                                    pred_clusters)
    print("Calculating CEAF-m")
    ceafm_recall, ceafm_precision, _ = scores.ceaf_m(gold_clusters,
                                                    pred_clusters)
    n_common_mentions = len(gold_mentions.intersection(pred_mentions))
    mention_recall = 0 if len(gold_mentions) == 0 else n_common_mentions/(
        len(gold_mentions))
    mention_precision = 0 if len(pred_mentions) == 0 else n_common_mentions/(
        len(pred_mentions))

    muc = Metric(muc_recall, muc_precision)
    b3 = Metric(b3_recall, b3_precision)
    ceafe = Metric(ceafe_recall, ceafe_precision)
    ceafm = Metric(ceafm_recall, ceafm_precision)
    mention_metric = Metric(mention_recall, mention_precision)
    scorch_metric = CoreferenceMetric(muc, b3, ceafe, ceafm, mention_metric)
    return scorch_metric

def convert_tensor_to_cluster(tensor: torch.LongTensor) -> set[data.Mention]:
    """Find the set of mentions from the annotated tensor"""
    cluster: set[data.Mention] = set()
    i = 0
    while i < len(tensor):
        if tensor[i] == 1:
            j = i + 1
            while j < len(tensor) and tensor[j] == 2:
                j += 1
            mention = data.Mention(i, j - 1)
            cluster.add(mention)
            i = j
        else:
            i += 1
    return cluster

def evaluate_tensors_scorch(groundtruth: torch.LongTensor, 
    predictions: torch.LongTensor, doc_ids: torch.IntTensor,
    corpus: data.CorefCorpus | None = None
    ) -> CoreferenceMetric | tuple[CoreferenceMetric, CoreferenceMetric]:
    """Evaluate the predictions against the groundtruth annotations using the
    unofficial python scorch library. The groundtruth and predictions are
    represented by tensors. If corpus is not None, also predict against the
    clusters of the corpus.

    Args:
        groundtruth: Integer tensor annotated with groundtruth cluster
            mentions.
        predictions: Integer tensor annotated with predicted cluster mentions.
        doc_ids: Integer tensor containing doc ids of the corresponding
            groundtruth and predictions tensors.
        corpus: Original coreference corpus from which the groundtruth tensors
            was created.
    
    Return:
        CoreferenceMetric or tuple of two CoreferenceMetric.
    """
    groundtruth_doc_id_to_clusters: dict[int, list[set[data.Mention]]] = {}
    predictions_doc_id_to_clusters: dict[int, list[set[data.Mention]]] = {}

    for doc_id, gt_tensor, pred_tensor in tqdm.tqdm(zip(
        doc_ids, groundtruth, predictions), total=len(doc_ids),
            desc="Convert Tensor to Cluster"):
        gt_cluster = convert_tensor_to_cluster(gt_tensor)
        pred_cluster = convert_tensor_to_cluster(pred_tensor)
        if len(gt_cluster):
            if doc_id not in groundtruth_doc_id_to_clusters:
                groundtruth_doc_id_to_clusters[doc_id] = []
            groundtruth_doc_id_to_clusters[doc_id].append(gt_cluster)
        if len(pred_cluster):
            if doc_id not in predictions_doc_id_to_clusters:
                predictions_doc_id_to_clusters[doc_id] = []
            predictions_doc_id_to_clusters[doc_id].append(gt_cluster)
    
    coref_metric1 = evaluate_clusters_scorch(
        groundtruth_doc_id_to_clusters, predictions_doc_id_to_clusters)
    
    if corpus is not None:
        corpus_doc_id_to_clusters: dict[int, list[set[data.Mention]]] = {}
        for document in corpus.documents:
            doc_id = document.doc_id
            if len(document.clusters):
                corpus_doc_id_to_clusters[doc_id] = document.clusters
        
        coref_metric2 = evaluate_clusters_scorch(
            corpus_doc_id_to_clusters, predictions_doc_id_to_clusters)
        return coref_metric1, coref_metric2
    else:
        return coref_metric1

In [8]:
(token_ids, mention_ids, label_ids, attn_mask,
    global_attn_mask, doc_ids) = train_dataset.tensors

In [9]:
print(f"label_ids = {label_ids.shape} {label_ids.dtype} {label_ids.device}")

label_ids = torch.Size([22601, 512]) torch.int64 cpu


In [10]:
prediction_ids = label_ids.clone()

In [20]:
n = 10000
coref_metric = evaluate_tensors_scorch(label_ids[:n], prediction_ids[:n],
    doc_ids[:n])

Convert Tensor to Cluster: 100%|██████████| 10000/10000 [01:17<00:00, 128.96it/s]
Collecting groundtruth clusters and mentions: 100%|██████████| 10000/10000 [00:00<00:00, 171624.09it/s]
Collecting predictions clusters and mentions: 100%|██████████| 10000/10000 [00:00<00:00, 11469.29it/s]


Calculating MUC
Calculating B3
Calculating CEAF-e
Calculating CEAF-m


In [23]:
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

In [24]:
train_corpus = data.CorefCorpus(
    ("/home/sbaruah_usc_edu/mica_text_coref/data/conll-2012/"
    "gold/train.english.jsonlines"), use_ascii_transliteration=True)
seq_train_corpus = data_util.remove_overlaps(train_corpus)
longformer_seq_train_corpus = data_util.remap_spans_document_level(
    seq_train_corpus, tokenizer.tokenize)

In [55]:
def convert_to_conll(doc_key_with_part_id: str, sentences: list[list[str]],
                    clusters: list[set[data.Mention]]) -> list[str]:
    """Create conll lines from clusters.

    Args:
        doc_key_with_part_id: The doc key in the jsonlines file.
        sentences: List of sentence. Each sentence is a list of tokens (string).
        clusters: List of cluster. Each cluster is a set of data.Mention
        objects.
    
    Returns:
        List of lines in conll-format. Each line contains the word and
        coreference tag.
    """
    match = re.match(r"(.+)_([^_]+)$", doc_key_with_part_id)
    doc_key, part_id = match.group(1), match.group(2)
    total_n_tokens = sum(len(sentence) for sentence in sentences)
    max_token_length = max(len(token) for sentence in sentences
                            for token in sentence)
    coref_column = ["-" for _ in range(total_n_tokens)]
    mentions = [(mention.begin, mention.end, i + 1)
                for i, cluster in enumerate(clusters) for mention in cluster]
    non_unigram_mentions_sorted_by_begin = sorted(
        filter(lambda mention: mention[1] - mention[0] > 0, mentions),
        key=lambda mention: (mention[0], -mention[1]))
    non_unigram_mentions_sorted_by_end = sorted(
        filter(lambda mention: mention[1] - mention[0] > 0, mentions),
        key=lambda mention: (mention[1], -mention[0]))
    unigram_mentions = filter(lambda mention: mention[1] == mention[0],
                            mentions)
    
    for begin, _, cluster_index in non_unigram_mentions_sorted_by_begin:
        if coref_column[begin] == "-":
            coref_column[begin] = "(" + str(cluster_index)
        else:
            coref_column[begin] += "|(" + str(cluster_index)
    
    for begin, _, cluster_index in unigram_mentions:
        if coref_column[begin] == "-":
            coref_column[begin] = "(" + str(cluster_index) + ")"
        else:
            coref_column[begin] += "|(" + str(cluster_index) + ")"
    
    for _, end, cluster_index in non_unigram_mentions_sorted_by_end:
        if coref_column[end] == "-":
            coref_column[end] = str(cluster_index) + ")"
        else:
            coref_column[end] += "|" + str(cluster_index) + ")"
    
    conll_lines = [f"#begin document {doc_key}; part {part_id.zfill(3)}\n"]
    filler = "  -" * 7
    i = 0
    for sentence in sentences:
        for j, token in enumerate(sentence):
            line = (f"{doc_key} {part_id} {j:>2} "
                   f"{token:>{max_token_length}}{filler} {coref_column[i]}\n")
            conll_lines.append(line)
            i += 1
        conll_lines.append("\n")
    conll_lines = conll_lines[:-1] + ["#end document\n"]

    return conll_lines

def evaluate_clusters_official(official_scorer: str,
    groundtruth: dict[int, list[set[data.Mention]]],
    predictions: dict[int, list[set[data.Mention]]],
    doc_id_to_doc_key: dict[int, str],
    doc_id_to_sentences: dict[int, list[list[str]]],
    verbose=False) -> CoreferenceMetric:
    """Evaluates the predictions against the groundtruth annotations using
    the official conll-2012 perl scorer. This function will throw an error if
    any key of the groundtruth dictionary is not present in both the
    doc_id_to_doc_key and doc_id_to_sentences dictionaries.

    Args:
        official_scorer: Path to the official perl script scorer.
        groundtruth: A dictionary of list of groundtruth coreference clusters 
            (set of data.Mention objects) keyed by the doc id.
        predictions: A dictionary of list of predicted coreference clusters
            (set of data.Mention objects) keyed by the doc id.
        doc_id_to_doc_key: A map (dictionary) from doc id to doc key.
        doc_id_to_sentences: A map (dictionary) from doc id to list of 
            sentences. Each sentence is a list of string words.
        verbose: set to true for verbose output
    
    Return:
        CoreferenceMetric. This contains scores for MUC, B3, CEAFe, CEAFm, and
        mention.
    """
    gold_conll_lines, pred_conll_lines = [], []
    
    for doc_id, gold_clusters in tqdm.tqdm(groundtruth.items(),
        desc="Creating conll", total=len(groundtruth)):
        doc_key = doc_id_to_doc_key[doc_id]
        sentences = doc_id_to_sentences[doc_id]
        pred_clusters = predictions[doc_id] if doc_id in predictions else []
        
        n_tokens = sum(len(s) for s in sentences)
        if gold_clusters:
            max_gold_end = max(mention.end for cluster in gold_clusters for mention in cluster)
        else:
            max_gold_end = -1
        if pred_clusters:
            max_pred_end = max(mention.end for cluster in pred_clusters for mention in cluster)
        else:
            max_pred_end = -1
        print(doc_id, doc_key, n_tokens, max_gold_end, max_pred_end)
        
        gold_document_conll_lines = convert_to_conll(
            doc_key, sentences, gold_clusters)
        pred_document_conll_lines = convert_to_conll(
            doc_key, sentences, pred_clusters)
        gold_conll_lines.extend(gold_document_conll_lines)
        pred_conll_lines.extend(pred_document_conll_lines)
    
    with tempfile.NamedTemporaryFile(mode="w", delete=True) as gold_file, \
        tempfile.NamedTemporaryFile(mode="w", delete=True) as pred_file:
        gold_file.writelines(gold_conll_lines)
        pred_file.writelines(pred_conll_lines)

        if verbose:
            print(f"Gold file = {gold_file.name}")
            print(f"Pred file = {pred_file.name}")

        cmd = [official_scorer, "all", gold_file.name, pred_file.name,
                "none"]
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
        stdout, stderr = process.communicate()
        process.wait()
        stdout = stdout.decode("utf-8")

        if verbose:
            if stderr is not None:
                print(stderr)
            if stdout:
                print("Official result")
                print(stdout)

        matched_tuples = re.findall(
            r"Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\s+"
            r"Precision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\s+F1:"
            r" ([0-9.]+)%", stdout, flags=re.DOTALL)
        
        muc = Metric(float(matched_tuples[0][0]), float(matched_tuples[0][1]))
        b3 = Metric(float(matched_tuples[1][0]), float(matched_tuples[1][1]))
        ceafm = Metric(float(matched_tuples[2][0]), float(matched_tuples[2][1]))
        ceafe = Metric(float(matched_tuples[3][0]), float(matched_tuples[3][1]))
        
        mention_match = re.search(
            r"Mentions: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\s+Precision:"
            r" \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\s+F1: ([0-9.]+)%", stdout,
            flags=re.DOTALL)
        mention_metric = Metric(float(mention_match.group(1)),
                                float(mention_match.group(2)))
        official_metric = CoreferenceMetric(muc, b3, ceafe, ceafm,
                                            mention_metric)
        return official_metric

def evaluate_tensors_official(official_scorer: str,
    groundtruth: torch.LongTensor, predictions: torch.LongTensor,
    doc_ids: torch.IntTensor, corpus: data.CorefCorpus
    ) -> tuple[CoreferenceMetric, CoreferenceMetric]:
    """Evaluate the predictions against the groundtruth annotations using the
    official conll-2012 perl scorer. The groundtruth and predictions are
    represented by tensors.

    Args:
        official_scorer: Path to the official perl script scorer.
        groundtruth: Integer tensor annotated with groundtruth cluster
            mentions.
        predictions: Integer tensor annotated with predicted cluster mentions.
        doc_ids: Integer tensor containing doc ids of the corresponding
            groundtruth and predictions tensors.
        corpus: Original coreference corpus from which the groundtruth tensors
            was created.
    
    Return:
        Tuple of two CoreferenceMetric objects.
    """
    corpus_doc_id_to_clusters: dict[int, list[set[data.Mention]]] = {}
    groundtruth_doc_id_to_clusters: dict[int, list[set[data.Mention]]] = {}
    predictions_doc_id_to_clusters: dict[int, list[set[data.Mention]]] = {}
    doc_id_to_doc_key: dict[int, str] = corpus.get_doc_id_to_doc_key()
    doc_id_to_sentences: dict[int, list[list[str]]] = (
        corpus.get_doc_id_to_sentences())

    for doc_id, gt_tensor, pred_tensor in tqdm.tqdm(zip(
        doc_ids, groundtruth, predictions),
        desc="Convert tensors to cluster", total=len(doc_ids)):
        gt_cluster = convert_tensor_to_cluster(gt_tensor)
        pred_cluster = convert_tensor_to_cluster(pred_tensor)
        doc_id = doc_id.item()
        if len(gt_cluster):
            if doc_id not in groundtruth_doc_id_to_clusters:
                groundtruth_doc_id_to_clusters[doc_id] = []
            groundtruth_doc_id_to_clusters[doc_id].append(gt_cluster)
        if len(pred_cluster):
            if doc_id not in predictions_doc_id_to_clusters:
                predictions_doc_id_to_clusters[doc_id] = []
            predictions_doc_id_to_clusters[doc_id].append(gt_cluster)

    for document in corpus.documents:
        doc_id = document.doc_id
        if len(document.clusters):
            corpus_doc_id_to_clusters[doc_id] = document.clusters
    
    coref_metric1 = evaluate_clusters_official(official_scorer, 
        groundtruth_doc_id_to_clusters, predictions_doc_id_to_clusters,
        doc_id_to_doc_key, doc_id_to_sentences)
    # coref_metric2 = evaluate_clusters_official(official_scorer, 
    #     corpus_doc_id_to_clusters, predictions_doc_id_to_clusters,
    #     doc_id_to_doc_key, doc_id_to_sentences)
    return coref_metric1

In [57]:
n = 10000
evaluate_tensors_official(
    ("/home/sbaruah_usc_edu/mica_text_coref/coref/seq_coref/scorer/"
        "v8.01/scorer.pl"), label_ids[:n], prediction_ids[:n],
    doc_ids[:n], longformer_seq_train_corpus)

Convert tensors to cluster: 100%|██████████| 10000/10000 [01:10<00:00, 141.52it/s]
Creating conll:   7%|▋         | 85/1257 [00:00<00:01, 839.62it/s]

0 bc/cctv/00/cctv_0001_0 401 346 346
1 bc/cctv/00/cctv_0001_1 643 508 508
2 bc/cctv/00/cctv_0001_2 1027 510 510
3 bc/cctv/00/cctv_0001_3 340 335 335
4 bc/cctv/00/cctv_0001_4 564 509 509
5 bc/cctv/00/cctv_0001_5 679 499 499
6 bc/cctv/00/cctv_0001_6 262 238 238
7 bc/cctv/00/cctv_0001_7 942 376 376
8 bc/cctv/00/cctv_0001_8 234 228 228
9 bc/cctv/00/cctv_0002_0 256 225 225
10 bc/cctv/00/cctv_0002_1 106 98 98
11 bc/cctv/00/cctv_0002_2 434 420 420
12 bc/cctv/00/cctv_0002_3 502 488 488
13 bc/cctv/00/cctv_0002_4 362 352 352
14 bc/cctv/00/cctv_0002_5 275 252 252
15 bc/cctv/00/cctv_0002_6 411 388 388
16 bc/cctv/00/cctv_0002_7 342 332 332
17 bc/cctv/00/cctv_0002_8 657 500 500
18 bc/cctv/00/cctv_0002_9 499 490 490
19 bc/cctv/00/cctv_0002_10 405 403 403
20 bc/cctv/00/cctv_0002_11 469 450 450
21 bc/cctv/00/cctv_0002_12 478 461 461
22 bc/cctv/00/cctv_0002_13 242 228 228
23 bc/cctv/00/cctv_0002_14 484 470 470
24 bc/cctv/00/cctv_0002_15 394 373 373
25 bc/cctv/00/cctv_0002_16 326 317 317
26 bc/cctv/00/cc

Creating conll:  19%|█▉        | 244/1257 [00:00<00:01, 735.28it/s]

162 bc/msnbc/00/msnbc_0001_0 405 400 400
163 bc/msnbc/00/msnbc_0001_1 687 501 501
164 bc/msnbc/00/msnbc_0001_2 951 505 505
165 bc/msnbc/00/msnbc_0001_3 561 508 508
166 bc/msnbc/00/msnbc_0001_4 1076 505 505
167 bc/msnbc/00/msnbc_0001_5 608 475 475
168 bc/msnbc/00/msnbc_0001_6 235 227 227
169 bc/msnbc/00/msnbc_0001_7 235 201 201
170 bc/msnbc/00/msnbc_0001_8 396 391 391
171 bc/msnbc/00/msnbc_0001_9 564 508 508
172 bc/msnbc/00/msnbc_0001_10 829 454 454
173 bc/msnbc/00/msnbc_0001_11 431 422 422
174 bc/msnbc/00/msnbc_0001_12 809 450 450
175 bc/msnbc/00/msnbc_0001_13 192 166 166
176 bc/msnbc/00/msnbc_0001_14 494 488 488
177 bc/msnbc/00/msnbc_0001_15 409 397 397
178 bc/msnbc/00/msnbc_0002_0 625 374 374
179 bc/msnbc/00/msnbc_0002_1 1085 500 500
180 bc/msnbc/00/msnbc_0002_2 589 503 503
181 bc/msnbc/00/msnbc_0002_3 886 501 501
182 bc/msnbc/00/msnbc_0002_4 320 309 309
183 bc/msnbc/00/msnbc_0002_5 453 451 451
184 bc/msnbc/00/msnbc_0002_6 390 359 359
185 bc/msnbc/00/msnbc_0002_7 822 460 460
186 bc/m

Creating conll:  42%|████▏     | 531/1257 [00:00<00:00, 1217.89it/s]

316 bn/abc/00/abc_0042_0 340 338 338
317 bn/abc/00/abc_0043_0 253 251 251
318 bn/abc/00/abc_0044_0 69 54 54
319 bn/abc/00/abc_0045_0 68 60 60
320 bn/abc/00/abc_0046_0 55 53 53
321 bn/abc/00/abc_0047_0 424 420 420
322 bn/abc/00/abc_0048_0 407 405 405
323 bn/abc/00/abc_0051_0 503 501 501
324 bn/abc/00/abc_0052_0 443 441 441
325 bn/abc/00/abc_0053_0 304 300 300
327 bn/abc/00/abc_0055_0 512 510 510
328 bn/abc/00/abc_0056_0 507 505 505
329 bn/abc/00/abc_0057_0 78 71 71
330 bn/abc/00/abc_0058_0 49 37 37
331 bn/abc/00/abc_0061_0 42 27 27
332 bn/abc/00/abc_0062_0 397 395 395
333 bn/abc/00/abc_0063_0 481 476 476
334 bn/abc/00/abc_0064_0 477 470 470
335 bn/abc/00/abc_0065_0 279 258 258
336 bn/abc/00/abc_0066_0 495 493 493
337 bn/abc/00/abc_0067_0 311 307 307
338 bn/abc/00/abc_0068_0 74 62 62
339 bn/cnn/00/cnn_0001_0 294 289 289
340 bn/cnn/00/cnn_0002_0 172 150 150
341 bn/cnn/00/cnn_0003_0 107 87 87
342 bn/cnn/00/cnn_0004_0 96 92 92
343 bn/cnn/00/cnn_0005_0 276 274 274
344 bn/cnn/00/cnn_0006_0 57

Creating conll:  62%|██████▏   | 775/1257 [00:00<00:00, 1020.35it/s]

657 bn/cnn/03/cnn_0397_0 567 501 501
658 bn/cnn/03/cnn_0398_0 98 97 97
659 bn/cnn/04/cnn_0401_0 261 243 243
660 bn/cnn/04/cnn_0402_0 286 284 284
661 bn/cnn/04/cnn_0403_0 70 43 43
662 bn/cnn/04/cnn_0404_0 735 507 507
663 bn/cnn/04/cnn_0405_0 130 114 114
664 bn/cnn/04/cnn_0406_0 283 261 261
665 bn/cnn/04/cnn_0407_0 243 238 238
666 bn/cnn/04/cnn_0408_0 606 508 508
667 bn/cnn/04/cnn_0411_0 1212 469 469
668 bn/cnn/04/cnn_0412_0 66 64 64
669 bn/cnn/04/cnn_0413_0 60 51 51
670 bn/cnn/04/cnn_0414_0 193 191 191
671 bn/cnn/04/cnn_0415_0 794 509 509
672 bn/cnn/04/cnn_0416_0 244 227 227
673 bn/cnn/04/cnn_0417_0 186 175 175
674 bn/cnn/04/cnn_0418_0 116 95 95
675 bn/cnn/04/cnn_0421_0 63 53 53
676 bn/cnn/04/cnn_0422_0 318 303 303
677 bn/cnn/04/cnn_0423_0 473 465 465
678 bn/cnn/04/cnn_0424_0 846 464 464
679 bn/cnn/04/cnn_0425_0 250 247 247
680 bn/cnn/04/cnn_0426_0 449 418 418
681 bn/cnn/04/cnn_0427_0 293 280 280
682 bn/cnn/04/cnn_0428_0 258 256 256
683 bn/cnn/04/cnn_0431_0 691 501 501
684 bn/cnn/04/cnn

Creating conll:  85%|████████▍ | 1068/1257 [00:00<00:00, 1245.97it/s]

837 bn/voa/00/voa_0005_0 235 231 231
839 bn/voa/00/voa_0007_0 201 196 196
840 bn/voa/00/voa_0008_0 76 66 66
841 bn/voa/00/voa_0011_0 114 100 100
842 bn/voa/00/voa_0012_0 402 385 385
843 bn/voa/00/voa_0013_0 103 68 68
844 bn/voa/00/voa_0014_0 414 409 409
845 bn/voa/00/voa_0015_0 133 131 131
846 bn/voa/00/voa_0016_0 56 50 50
848 bn/voa/00/voa_0018_0 54 52 52
849 bn/voa/00/voa_0021_0 814 492 492
850 bn/voa/00/voa_0022_0 296 294 294
851 bn/voa/00/voa_0023_0 314 312 312
852 bn/voa/00/voa_0024_0 458 456 456
853 bn/voa/00/voa_0025_0 53 44 44
854 bn/voa/00/voa_0026_0 117 106 106
855 bn/voa/00/voa_0027_0 160 147 147
856 bn/voa/00/voa_0028_0 231 229 229
857 bn/voa/00/voa_0031_0 93 91 91
858 bn/voa/00/voa_0032_0 311 309 309
859 bn/voa/00/voa_0033_0 215 194 194
860 bn/voa/00/voa_0034_0 69 32 32
861 bn/voa/00/voa_0035_0 207 205 205
862 bn/voa/00/voa_0036_0 92 87 87
863 bn/voa/00/voa_0037_0 159 150 150
864 bn/voa/00/voa_0038_0 348 346 346
865 bn/voa/00/voa_0041_0 244 242 242
866 bn/voa/00/voa_0042_0

Creating conll: 100%|██████████| 1257/1257 [00:01<00:00, 1004.65it/s]

1140 mz/sinorama/10/ectb_1014_1 449 430 430
1141 mz/sinorama/10/ectb_1014_2 436 432 432
1142 mz/sinorama/10/ectb_1014_3 361 359 359
1143 mz/sinorama/10/ectb_1014_4 492 462 462
1144 mz/sinorama/10/ectb_1014_5 447 444 444
1145 mz/sinorama/10/ectb_1014_6 368 357 357
1146 mz/sinorama/10/ectb_1014_7 493 476 476
1147 mz/sinorama/10/ectb_1014_8 772 337 337
1148 mz/sinorama/10/ectb_1014_9 585 433 433
1149 mz/sinorama/10/ectb_1015_0 336 334 334
1150 mz/sinorama/10/ectb_1015_1 459 446 446
1151 mz/sinorama/10/ectb_1015_2 269 257 257
1152 mz/sinorama/10/ectb_1016_0 523 481 481
1153 mz/sinorama/10/ectb_1016_1 642 473 473
1154 mz/sinorama/10/ectb_1016_2 715 373 373
1155 mz/sinorama/10/ectb_1016_3 511 501 501
1156 mz/sinorama/10/ectb_1016_4 396 289 289
1157 mz/sinorama/10/ectb_1016_5 484 461 461
1158 mz/sinorama/10/ectb_1016_6 711 491 491
1159 mz/sinorama/10/ectb_1016_7 506 503 503
1160 mz/sinorama/10/ectb_1016_8 193 177 177
1161 mz/sinorama/10/ectb_1017_0 396 394 394
1162 mz/sinorama/10/ectb_1017_1 




MUC: P = 10000.0, R = 10000.0, F1 = 10000.0
B3: P = 10000.0, R = 10000.0, F1 = 10000.0
CEAFe: P = 10000.0, R = 10000.0, F1 = 10000.0
Average F1: 10000.0
Mention: P = 10000.0, R = 10000.0, F1 = 10000.0

In [35]:
label_ids[3]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 2,
        2, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [40]:
(token_ids[3] != 1).int()

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [38]:
print(longformer_seq_train_corpus.documents[3])

<mica_text_coref.coref.seq_coref.data.CorefDocument object at 0x7f4f438548b0>


In [42]:
convert_tensor_to_cluster(prediction_ids[3])

{(42,49), (51,51)}

In [46]:
torch.where(doc_ids == 3)

(tensor([38, 39, 40, 41, 42, 43, 44, 45]),)

In [49]:
(token_ids[38] == 0).sum()

tensor(0)

In [50]:
longformer_seq_train_corpus.documents[3].doc_id

3

In [53]:
longformer_seq_train_corpus.documents[3].doc_key

'bc/cctv/00/cctv_0001_3'

In [52]:
i = 0
for sentence in longformer_seq_train_corpus.get_doc_id_to_sentences()[3]:
    for token in sentence:
        print(i, token)
        i += 1

0 On
1 Ġthe
2 Ġafternoon
3 Ġof
4 ĠAugust
5 Ġ22
6 Ġ,
7 ĠPeng
8 ĠDe
9 hu
10 ai
11 Ġwas
12 Ġlistening
13 Ġto
14 Ġthe
15 Ġcombat
16 Ġoperation
17 Ġdirector
18 Ġreport
19 Ġon
20 Ġbattle
21 Ġdevelopments
22 Ġat
23 ĠEighth
24 ĠRoute
25 ĠArmy
26 Ġoperational
27 Ġheadquarters
28 Ġ.
29 ĠWhen
30 Ġasked
31 Ġabout
32 Ġthe
33 Ġactual
34 Ġcombat
35 Ġstrength
36 Ġof
37 Ġthe
38 ĠEighth
39 ĠRoute
40 ĠArmy
41 Ġ,
42 Ġthe
43 Ġcombat
44 Ġoperation
45 Ġdirector
46 Ġreported
47 Ġ:
48 ĠThere
49 Ġare
50 Ġ30
51 Ġreg
52 iments
53 Ġalong
54 Ġthe
55 ĠZheng
56 t
57 ai
58 ĠLine
59 Ġ,
60 Ġ15
61 Ġreg
62 iments
63 Ġalong
64 Ġthe
65 ĠLug
66 ou
67 ĠBridge
68 Ġ-
69 ĠHand
70 an
71 Ġsection
72 Ġof
73 Ġthe
74 ĠPing
75 han
76 Ġline
77 Ġ,
78 Ġ12
79 Ġreg
80 iments
81 Ġalong
82 Ġthe
83 ĠDat
84 ong
85 Ġ-
86 ĠHong
87 d
88 ong
89 Ġsection
90 Ġin
91 ĠTon
92 gh
93 u
94 ĠCounty
95 Ġ,
96 Ġand
97 Ġfour
98 Ġreg
99 iments
100 Ġalong
101 Ġthe
102 ĠTian
103 jin
104 Ġ-
105 ĠDe
106 zhou
107 Ġsection
108 Ġof
109 Ġthe
110 ĠJin
111 pu
112 ĠLine
1