In [44]:
# Evaluation inputs/outputs.
# scorer_path: the scorer.pl script that compute_conll_score executes.
# test_conll_path: gold CoNLL file; it is parsed into documents and also passed
#   to the scorer as the ground truth.
scorer_path = "/root/workspace/fast-coref/coref_resources/reference-coreference-scorers/scorer.pl"
test_conll_path = "/root/workspace/sr_coref/src/benchmarking/data/radcoref_test.conll"

# CoNLL file written from the model predictions and passed to the scorer as the prediction.
output_file_path = "/root/workspace/sr_coref/src/benchmarking/data/radcoref_test_pred.conll"

In [30]:
# input_jsonlines_path: model input written by this notebook (one JSON document per line).
# pred_jsonlines_path: predictions written by predict.py and read back by this notebook.
input_jsonlines_path = "/root/workspace/sr_coref/src/benchmarking/data/radcoref_input.jsonlines"
pred_jsonlines_path = "/root/workspace/sr_coref/src/benchmarking/data/radcoref_pred.jsonlines"

In [31]:
def extract_custom_conll(row):
    """Extract the token text and the coreference column from a tab-separated CoNLL row.

    The expected layout is: a document column, two integer columns, the token,
    eight "_" placeholder columns and an optional trailing coreference column,
    all separated by tabs.

    Args:
        row: One token line of the CoNLL file, without its trailing newline.

    Returns:
        A ``(token_str, token_coref_ids)`` tuple. ``token_coref_ids`` is the raw
        text of the trailing column (for example "(3|5)"), or ``None`` when the
        row has no such column.

    Raises:
        ValueError: If ``row`` does not follow the expected layout.
    """
    obj = re.match(r".+\t\d+\t\d+\t(.*?)(\t_){8}(\t(.+))?", row)
    if obj is None:
        # Without this guard a malformed row fails with the opaque
        # "'NoneType' object has no attribute 'group'".
        raise ValueError(f"Row does not match the custom CoNLL layout: {row!r}")
    token_str = obj.group(1)
    token_coref_ids = obj.group(4)
    return token_str, token_coref_ids

def extract_onto_conll(row):
    """Extract the token and the coreference column from a space-aligned CoNLL row.

    Columns are separated by runs of spaces. The token is the fourth column and
    the coreference annotation is the last one, where "-" means "no annotation".

    Returns:
        A ``(token_str, token_coref_ids)`` tuple; ``token_coref_ids`` is ``None``
        when the last column is "-".
    """
    columns = re.split(r" +", row)
    coref_column = columns[-1]
    return columns[3], (None if coref_column == "-" else coref_column)

# Row parser used when loading the gold CoNLL file. Both extractors return
# (token_str, token_coref_ids); use extract_onto_conll for space-aligned files.
row_info_extractor = extract_custom_conll

# Load radcoref test conll

In [32]:
import re

# Load the gold CoNLL file into `rows`: one entry per line, line break removed.
with open(test_conll_path, "r", encoding="utf-8") as f:
    rows = [line.rstrip("\n") for line in f]

In [33]:
class ConllDocument:
    """One CoNLL document: its sentences, token indices and coreference clusters.

    The object is filled incrementally while the file is read line by line.
    Token indices are document-level: they keep counting across sentences.
    """

    def __init__(self, doc_key):
        self.doc_key = doc_key
        # Token strings grouped by sentence: [[tok, ...], ...]
        self.sent_toks = []
        # Document-level token indices, nested exactly like sent_toks.
        self.sent_tok_idx = []
        # gt_clusters[cluster_id] -> [[start, end], ...]; end stays None while a span is open.
        self.gt_clusters = []
        # pred_clusters[cluster_id] -> [[start, end], ...]
        self.pred_clusters = []

        # When True, the next add_token() call opens a new sentence.
        self._new_sent = True
        # Index the next token will receive.
        self._tok_pointer = 0

    def add_token(self, token):
        """Append ``token`` to the current sentence and return its document-level index."""
        if self._new_sent:
            self.sent_toks.append([])
            self.sent_tok_idx.append([])
            self._new_sent = False
        tok_idx = self._tok_pointer
        self.sent_toks[-1].append(token)
        self.sent_tok_idx[-1].append(tok_idx)
        self._tok_pointer += 1
        return tok_idx

    def add_gt_cluster(self, token_coref_id, span_start, span_end):
        """Record the start and/or end of a gold mention in cluster ``token_coref_id``.

        Args:
            token_coref_id: Integer cluster id; the cluster list grows as needed.
            span_start: Token index where the mention opens, or ``None`` when this
                call only closes a mention opened earlier.
            span_end: Token index where the mention closes, or ``None`` when the
                mention is still open.

        Raises:
            ValueError: If ``span_start`` is ``None`` but the cluster has no open
                mention to close.
        """
        while len(self.gt_clusters) < (token_coref_id + 1):
            self.gt_clusters.append([])
        spans = self.gt_clusters[token_coref_id]

        if span_start is not None:
            spans.append([span_start, span_end])
            return

        # Closing label: complete the most recently opened span that has no end
        # yet, so nested mentions of the same cluster close innermost-first.
        open_span = next((span for span in reversed(spans) if span[1] is None), None)
        if open_span is None:
            raise ValueError(
                f"Document {self.doc_key!r}: cluster {token_coref_id} is closed at "
                f"token {span_end} but has no open span."
            )
        open_span[1] = span_end

    def add_pred_cluster(self, coref_id, span_start, span_end):
        """Append the span ``[span_start, span_end]`` to predicted cluster ``coref_id``."""
        while len(self.pred_clusters) < (coref_id + 1):
            self.pred_clusters.append([])
        self.pred_clusters[coref_id].append([span_start, span_end])

    def __repr__(self):
        return f"{self.doc_key}: {self.gt_clusters}"

In [34]:
# Parse the CoNLL rows into one ConllDocument per
# "#begin document (...)" ... "#end document" block.
current_doc_obj = None
doc_objs = []
for row in rows:
    if row == "":
        # A blank line is a sentence boundary inside a document and plain
        # padding between documents.
        if current_doc_obj is not None:
            current_doc_obj._new_sent = True
        continue

    if row.startswith("#begin"):
        obj = re.match(r"#begin document \((.+)\); part 0", row)
        if obj is None:
            # Previously this surfaced as an AttributeError on None.
            raise ValueError(f"Unsupported document header: {row!r}")
        dockey = obj.group(1)
        current_doc_obj = ConllDocument(dockey)
    elif row == "#end document":
        if current_doc_obj is None:
            # Appending None here would only fail much later, far from the cause.
            raise ValueError("Found '#end document' without a matching '#begin document'.")
        doc_objs.append(current_doc_obj)
        current_doc_obj = None
    else:
        if current_doc_obj is None:
            raise ValueError(f"Token row found outside of a document: {row!r}")

        token_str, token_coref_ids = row_info_extractor(row)

        # Register the token and get its document-level index.
        tok_idx = current_doc_obj.add_token(token_str)

        # The coreference column holds "|"-separated labels such as "(3", "3)" or
        # "(3)": "(" opens a mention at this token and ")" closes one.
        if token_coref_ids:
            for token_coref_id_str in token_coref_ids.split("|"):
                token_coref_id = int(token_coref_id_str.strip("()"))
                span_start = tok_idx if token_coref_id_str.startswith("(") else None
                span_end = tok_idx if token_coref_id_str.endswith(")") else None
                current_doc_obj.add_gt_cluster(token_coref_id, span_start, span_end)

# Build input for model

See https://github.com/kareldo/wl-coref for more details

To predict coreference relations on an arbitrary text, you will need to prepare the data in the jsonlines format (one json-formatted document per line). The following fields are required:
{
    "document_id": "tc_mydoc_001",
    "cased_words": ["Hi", "!", "Bye", "."],
    "sent_id": [0, 0, 1, 1]
}

document_id can be any string that starts with a two-letter genre identifier. The genres recognized are the following:

bc: broadcast conversation
bn: broadcast news
mz: magazine genre (Sinorama magazine)
nw: newswire genre
pt: pivot text (The Bible)
tc: telephone conversation (CallHome corpus)
wb: web data

In [35]:
import json

In [36]:
# Write one model-input record per document (one JSON object per line).
# The file is opened once in "w" mode: the previous per-document "a" mode meant
# that re-running this cell appended a second copy of every document.
with open(input_jsonlines_path, "w", encoding="utf-8") as f:
    for doc_obj in doc_objs:
        # The "wb_" genre prefix is added because document_id must start with a
        # two-letter genre identifier.
        doc_id = f"wb_{doc_obj.doc_key}"
        cased_words = []
        sentences_map = []
        for sent_id, sent in enumerate(doc_obj.sent_toks):
            cased_words.extend(sent)
            # One sentence id per token, aligned with cased_words.
            sentences_map.extend([sent_id] * len(sent))
        out = {
            "document_id": doc_id,
            "cased_words": cased_words,
            "sent_id": sentences_map,
        }
        f.write(json.dumps(out))
        f.write("\n")

# Using caw-coref

In [47]:
# Run predict.py on $input_jsonlines_path and write the predictions to
# $pred_jsonlines_path, using the given weights and config file.
# NOTE(review): "roberta" is presumably the configuration section to use — confirm against predict.py.
!python \
    /root/workspace/sr_coref/src/benchmarking/wl-coref/predict.py \
    roberta \
    $input_jsonlines_path \
    $pred_jsonlines_path \
    --weights /root/autodl-tmp/hg_offline_models/caw-coref/roberta_release.pt \
    --config-file /root/workspace/sr_coref/src/benchmarking/wl-coref/config.toml

Loading /root/autodl-tmp/hg_offline_models/roberta-large...
Using tokenizer kwargs: {'add_prefix_space': True}


Some weights of the model checkpoint at /root/autodl-tmp/hg_offline_models/roberta-large were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Bert successfully loaded.
Loading from /root/autodl-tmp/hg_offline_models/caw-coref/roberta_release.pt...
Loaded bert
Loaded we
Loaded rough_scorer
Loaded pw
Loaded a_scorer
Loaded sp
100%|███████████████████████████████████████| 200/200 [00:08<00:00, 24.49docs

# Output pred conll file

In [48]:
class ConllToken(object):
    """One row of the predicted CoNLL file.

    Stores the token's position and text, and accumulates the coreference
    labels (for example "(3", "3)" or "(3)") written in the last column.
    """

    def __init__(self, docId, sentenceId, tokenId, tokenStr):
        self.docId = docId
        self.sentenceId = sentenceId
        self.tokenId = tokenId
        self.tokenStr = tokenStr
        # "|"-joined coreference labels; empty while the token has none.
        self.corefLabel = ""

    def add_coref_label(self, label, label_type):
        """Attach one coreference label to this token.

        Args:
            label: Cluster id to write.
            label_type: "start" -> "(id", "end" -> "id)", "both" -> "(id)".

        Raises:
            ValueError: If ``label_type`` is not one of the three values above.
                Previously such a value silently produced an unbracketed label.
        """
        if label_type == "start":
            label = f"({label}"
        elif label_type == "end":
            label = f"{label})"
        elif label_type == "both":
            label = f"({label})"
        else:
            raise ValueError(
                f"Unknown label_type {label_type!r}; expected 'start', 'end' or 'both'."
            )

        if not self.corefLabel:
            self.corefLabel = label
        else:
            self.corefLabel = f"{self.corefLabel}|{label}"

    def get_conll_str(self):
        """Return this token as one tab-separated CoNLL row (no trailing newline).

        The row is: docId, a literal "0", tokenId, the token text, eight "_"
        placeholder columns and, only when labels were added, the coreference
        column.
        """
        # IMPORTANT! Any tokens that trigger regex: \((\d+) or (\d+)\) will also
        # trigger "conll/reference-coreference-scorers" unexpectedly,
        # which will either cause execution error or wrong metric score.
        # See coref/wrong_conll_scorer_example for details.
        # Parentheses are therefore always rewritten to brackets. The former
        # regex guard matched every token containing a non-letter character,
        # i.e. every token with a parenthesis, so it never prevented a rewrite.
        tok_str = self.tokenStr.replace("(", "[").replace(")", "]")
        # Whitespace-only tokens are written as an empty column.
        if tok_str.strip() == "":
            tok_str = ""
        row = f"{self.docId}\t0\t{self.tokenId}\t{tok_str}" + "\t_" * 8
        if self.corefLabel:
            row += f"\t{self.corefLabel}"
        return row

    def __str__(self) -> str:
        return f"{self.tokenStr}({self.sentenceId}:{self.tokenId})|[{self.corefLabel}]"

    __repr__ = __str__

In [49]:
with open(pred_jsonlines_path, "r", encoding="utf-8") as f:
    pred_docs = f.readlines()

# Map each original doc_key to its prediction record.
pred_doc_dict = {}
for pred_doc in pred_docs:
    if not pred_doc.strip():
        # Tolerate blank lines (e.g. a trailing newline) instead of failing in json.loads.
        continue
    doc = json.loads(pred_doc)
    doc_key = doc["document_id"]
    # Remove exactly the "wb_" prefix added when the input was built.
    # str.lstrip("wb_") strips any leading run of the characters 'w', 'b' and
    # '_', so it also ate the start of keys such as "wb_bwh01" -> "h01".
    if doc_key.startswith("wb_"):
        doc_key = doc_key[len("wb_"):]
    pred_doc_dict[doc_key] = doc

In [52]:
SENTENCE_SEPARATOR = "\n"
END = "#end document\n"

# Open the output once in "w" mode: with the previous per-document "a" mode,
# re-running this cell appended a second copy of every document to the file
# that is later handed to the scorer.
with open(output_file_path, "w", encoding="UTF-8") as out:
    for doc_obj in doc_objs:
        BEGIN = f"#begin document ({doc_obj.doc_key}); part 0\n"

        # One ConllToken per gold token, grouped by sentence; tokenId restarts
        # at 0 in every sentence.
        sentence_list = []
        for sent_id, sent in enumerate(doc_obj.sent_toks):
            token_list = [
                ConllToken(docId=doc_obj.doc_key, sentenceId=sent_id, tokenId=tok_id, tokenStr=tok)
                for tok_id, tok in enumerate(sent)
            ]
            sentence_list.append(token_list)

        # Flat document-level view so predicted span offsets can index tokens directly.
        conll_tokens = [c_tok for sent in sentence_list for c_tok in sent]
        for coref_id, cluster in enumerate(pred_doc_dict[doc_obj.doc_key]["span_clusters"]):
            for span in cluster:
                start_idx = span[0]
                # The predicted span end is treated as exclusive, hence the -1.
                end_idx = span[1] - 1
                if start_idx == end_idx:
                    conll_tokens[start_idx].add_coref_label(coref_id, label_type="both")
                else:
                    conll_tokens[start_idx].add_coref_label(coref_id, label_type="start")
                    conll_tokens[end_idx].add_coref_label(coref_id, label_type="end")

        out.write(BEGIN)
        for sent in sentence_list:
            for tok in sent:
                out.write(tok.get_conll_str() + "\n")
            out.write(SENTENCE_SEPARATOR)
        out.write(END)
        out.write(SENTENCE_SEPARATOR)

# Eval

In [53]:
import subprocess
from subprocess import PIPE

In [54]:
def invoke_conll_script(
    scorer_path: str, use_which_metric: str, groundtruth_file_path: str, predicted_file_path: str
):
    """Run the CoNLL scorer script once and capture what it prints.

    The script is executed directly as
    ``scorer_path use_which_metric groundtruth_file_path predicted_file_path none``.

    Args:
        scorer_path: Path of the CoNLL scorer script (scorer.pl).
        use_which_metric: Metric name handed to the scorer, e.g. muc, bcub, ceafe.
        groundtruth_file_path: CoNLL file used as ground truth.
        predicted_file_path: CoNLL file holding the predictions.

    Returns:
        out: The script's standard output, decoded as UTF-8.
        err: The script's standard error, decoded as UTF-8. When it is not empty,
            the executed command is appended to it; otherwise it is "".
    """
    command = [scorer_path, use_which_metric, groundtruth_file_path, predicted_file_path, "none"]
    completed = subprocess.run(command, stdout=PIPE, stderr=PIPE)
    out, err = (stream.decode("utf-8") for stream in (completed.stdout, completed.stderr))
    return out, (f"{err} Error command: {command}" if err else err)

def resolve_conll_script_output(output_str):
    """Read the first six percentage values from a single-metric scorer output.

    Args:
        output_str: Standard output of scorer.pl for one metric
            (muc, bcub, ceafe or ceafm).

    Returns:
        A 6-tuple of floats, in order of appearance in the text:
        (mention_recall, mention_precision, mention_f1,
        coref_recall, coref_precision, coref_f1). The "%" sign is dropped.
    """
    percentages = re.findall(r"(\d*\.?\d*)%", output_str)
    scores = list(map(float, percentages))
    # Index one position at a time so that a short output still raises IndexError.
    return tuple(scores[position] for position in range(6))

def compute_conll_score(conll_file_gt, conll_file_pred):
    """Score a predicted CoNLL file against the gold file and print the results.

    Runs the scorer once per metric (muc, bcub, ceafe), prints the mention and
    coreference recall/precision/F1 of each, then prints the mean of the three
    coreference F1 values as "Overall F1".

    Args:
        conll_file_gt: Path of the gold CoNLL file.
        conll_file_pred: Path of the predicted CoNLL file.

    Raises:
        RuntimeError: If the scorer output for a metric cannot be parsed. The
            message carries the scorer's stderr, which used to be discarded and
            left only a bare IndexError.
    """
    print("gt:", conll_file_gt)
    print("pred:", conll_file_pred)
    overall_f1 = []
    for metric in ['muc', 'bcub', 'ceafe']:
        out, err = invoke_conll_script(scorer_path, metric, conll_file_gt, conll_file_pred)
        try:
            (
                mention_recall,
                mention_precision,
                mention_f1,
                coref_recall,
                coref_precision,
                coref_f1,
            ) = resolve_conll_script_output(out)
        except (IndexError, ValueError) as exc:
            raise RuntimeError(
                f"Could not parse the '{metric}' scores from the scorer output. "
                f"Scorer stderr: {err or '<empty>'}"
            ) from exc
        overall_f1.append(coref_f1)
        print(f"Metric: {metric}")
        print(f"mention_recall, mention_precision, mention_f1: {mention_recall}, {mention_precision}, {mention_f1}")
        print(f"coref_recall, coref_precision, coref_f1: {coref_recall}, {coref_precision}, {coref_f1}")

    print(f"Overall F1: {sum(overall_f1) / len(overall_f1)}")

In [55]:
# Score the predicted CoNLL file against the gold test file; results are printed per metric.
compute_conll_score(conll_file_gt=test_conll_path, conll_file_pred=output_file_path)

gt: /root/workspace/sr_coref/src/benchmarking/data/radcoref_test.conll
pred: /root/workspace/sr_coref/src/benchmarking/data/radcoref_test_pred.conll
Metric: muc
mention_recall, mention_precision, mention_f1: 66.21, 82.94, 73.64
coref_recall, coref_precision, coref_f1: 59.19, 74.8, 66.08
Metric: bcub
mention_recall, mention_precision, mention_f1: 66.21, 82.94, 73.64
coref_recall, coref_precision, coref_f1: 61.72, 77.73, 68.81
Metric: ceafe
mention_recall, mention_precision, mention_f1: 66.21, 82.94, 73.64
coref_recall, coref_precision, coref_f1: 65.51, 81.22, 72.52
Overall F1: 69.13666666666666
