In [5]:
# Path to the Perl reference coreference scorer (scorer.pl) inside the fast-coref checkout.
scorer_path = "/root/workspace/fast-coref/coref_resources/reference-coreference-scorers/scorer.pl"

# Gold test set in CoNLL format, and the jsonlines file that minimize_partition() writes from it.
test_conll_path = "/root/workspace/sr_coref/src/benchmarking/data/radcoref_test.conll"
test_jsonlines_path = "/root/workspace/sr_coref/src/benchmarking/data/radcoref_test.jsonlines"

# NOTE(review): presumably the destination for predictions in CoNLL format — confirm.
output_file_path = "/root/workspace/sr_coref/src/benchmarking/data/radcoref_pred_test.conll"

# Root of the CorefQA checkout; appended to sys.path and used to locate the vocab file.
corefqa_repo_path = "/root/workspace/sr_coref/src/benchmarking/CorefQA/"

In [6]:
import sys

# Put the CorefQA checkout on the import path.
# NOTE(review): presumably this is what lets the `conll`, `util` and `bert` imports below resolve — confirm.
sys.path.append(corefqa_repo_path)

# Convert radcoref test conll to jsonlines for CorefQA

In the jsonlines output, `clusters` are stored as subtoken indices (subtok_idx). To recover the document-token indices (doctok_idx), map each subtoken index through `subtoken_map`.

In [7]:
from __future__ import absolute_import, division, print_function

import collections
import json
import os
import re
import sys

import conll
import util
from bert import tokenization

# Vocabulary file shipped in the CorefQA checkout under cased_config_vocab/.
vocab_file_path = os.path.join(corefqa_repo_path, "cased_config_vocab", "vocab.txt")
# Tokenizer used to split words into subtokens; lower-casing is disabled.
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file_path, do_lower_case=False)

In [51]:
class DocumentState(object):
    """Accumulates the per-subtoken state of one document and exports it as a dict.

    ``get_document`` fills the token-level lists (tokens, subtokens, info,
    token_end, sentence_end, subtoken_map), ``split_into_segments`` fills the
    segment-level lists (segments, segment_subtoken_map, segment_info), and
    ``finalize`` derives speakers, pronouns and clusters from them.
    """

    def __init__(self, key):
        self.doc_key = key
        # One bool per subtoken; True on the last subtoken before a blank line.
        self.sentence_end = []
        # One bool per subtoken; True on the last subtoken of each word.
        self.token_end = []
        # Words as read from the input rows (after normalize_word).
        self.tokens = []
        # Flat list of subtokens for the whole document.
        self.subtokens = []
        # One entry per subtoken: row columns + [subtoken count] on the first
        # subtoken of a word, None on the remaining subtokens of that word.
        self.info = []
        # Per segment: ["[CLS]"] + subtokens + ["[SEP]"].
        self.segments = []
        # Word index of every subtoken.
        self.subtoken_map = []
        # Per segment: subtoken_map slice padded with one entry at each end.
        self.segment_subtoken_map = []
        # NOTE(review): never written in this class; finalize() uses a local instead.
        self.sentence_map = []
        # Positions (counted over all segment entries) whose column 4 is "PRP".
        self.pronouns = []
        # cluster id -> list of (first_subtoken_index, last_subtoken_index).
        self.clusters = collections.defaultdict(list)
        # cluster id -> stack of start indices of mentions that are still open.
        self.coref_stacks = collections.defaultdict(list)
        # Per segment: one speaker label per entry.
        self.speakers = []
        # Per segment: [None] + info slice + [None].
        self.segment_info = []
        
    def __repr__(self):
        return f"tokens: {self.tokens}"
    
    def finalize(self):
        """Derive speakers, pronouns and clusters, then return the document dict.

        Expects ``segments``, ``segment_subtoken_map`` and ``segment_info`` to
        have been filled already (done by ``split_into_segments``).

        Returns:
            dict with keys doc_key, sentences, speakers, constituents, ner,
            clusters, sentence_map, subtoken_map and pronouns.

        Raises:
            AssertionError: if a mention occurs in more than one merged cluster
                or the per-subtoken lists disagree in length.
        """
        # finalized: segments, segment_subtoken_map
        # populate speakers from info
        subtoken_idx = 0
        for segment in self.segment_info:
            speakers = []
            for i, tok_info in enumerate(segment):
                # None at either end of a segment is the [CLS]/[SEP] placeholder.
                if tok_info is None and (i == 0 or i == len(segment) - 1):
                    speakers.append("[SPL]")
                elif tok_info is None:
                    # Non-initial subtoken of a word: repeat the previous label.
                    speakers.append(speakers[-1])
                else:
                    # NOTE(review): column 9 is presumably the speaker and column 4 the POS tag — confirm against the input format.
                    speakers.append(tok_info[9])
                    if tok_info[4] == "PRP":
                        self.pronouns.append(subtoken_idx)
                subtoken_idx += 1
            self.speakers += [speakers]
        # NOTE(review): the "populate sentence map" comment below is stale; the
        # sentence map is computed further down via get_sentence_map().
        # populate sentence map

        # populate clusters
        # Index runs over every segment entry, including the [CLS]/[SEP] slots.
        first_subtoken_index = -1
        # NOTE(review): seg_idx and the `speakers` list below are unused in this loop.
        for seg_idx, segment in enumerate(self.segment_info):
            speakers = []
            for i, tok_info in enumerate(segment):
                first_subtoken_index += 1
                # tok_info is row + [subtoken count], so [-2] is the row's last
                # column (the coref annotation) and [-1] is the subtoken count.
                coref = tok_info[-2] if tok_info is not None else "-"
                if coref != "-":
                    last_subtoken_index = first_subtoken_index + \
                        tok_info[-1] - 1
                    for part in coref.split("|"):
                        if part[0] == "(":
                            if part[-1] == ")":
                                # "(N)": mention that starts and ends on this word.
                                cluster_id = int(part[1:-1])
                                self.clusters[cluster_id].append(
                                    (first_subtoken_index, last_subtoken_index))
                            else:
                                # "(N": mention opens here; remember its start.
                                cluster_id = int(part[1:])
                                self.coref_stacks[cluster_id].append(
                                    first_subtoken_index)
                        else:
                            # "N)": closes the most recently opened mention of cluster N.
                            cluster_id = int(part[:-1])
                            start = self.coref_stacks[cluster_id].pop()
                            self.clusters[cluster_id].append(
                                (start, last_subtoken_index))
        # merge clusters
        # A cluster sharing a mention with an already merged one is folded into it.
        merged_clusters = []
        for c1 in self.clusters.values():
            existing = None
            for m in c1:
                for c2 in merged_clusters:
                    if m in c2:
                        existing = c2
                        break
                if existing is not None:
                    break
            if existing is not None:
                print("Merging clusters (shouldn't happen very often.)")
                existing.update(c1)
            else:
                merged_clusters.append(set(c1))
        merged_clusters = [list(c) for c in merged_clusters]
        all_mentions = util.flatten(merged_clusters)
        sentence_map = get_sentence_map(self.segments, self.sentence_end)
        subtoken_map = util.flatten(self.segment_subtoken_map)
        # Sanity checks: no duplicated mention, and all per-entry lists line up.
        assert len(all_mentions) == len(set(all_mentions))
        num_words = len(util.flatten(self.segments))
        assert num_words == len(util.flatten(self.speakers))
        assert num_words == len(subtoken_map), (num_words, len(subtoken_map))
        assert num_words == len(sentence_map), (num_words, len(sentence_map))
        return {"doc_key": self.doc_key, "sentences": self.segments, "speakers": self.speakers, "constituents": [], "ner": [], "clusters": merged_clusters, "sentence_map": sentence_map, "subtoken_map": subtoken_map, "pronouns": self.pronouns}
    
def get_sentence_map(segments, sentence_end):
    """Return one sentence index per segment entry.

    Every segment is expected to carry one leading and one trailing marker
    around its subtokens; ``sentence_end`` holds one flag per inner subtoken,
    across all segments. The leading marker gets the sentence index current at
    the start of the segment, the trailing marker the index current at its end.

    Raises:
        AssertionError: if ``sentence_end`` does not have exactly one flag per
            inner subtoken.
    """
    inner_total = sum(len(seg) - 2 for seg in segments)
    assert len(sentence_end) == inner_total

    sentence_ids = []
    sentence_id = 0
    cursor = 0
    for seg in segments:
        # Leading marker.
        sentence_ids.append(sentence_id)
        inner_stop = cursor + len(seg) - 2
        while cursor < inner_stop:
            sentence_ids.append(sentence_id)
            # The index advances only after the subtoken that ends a sentence.
            sentence_id += int(sentence_end[cursor])
            cursor += 1
        # Trailing marker.
        sentence_ids.append(sentence_id)
    return sentence_ids


def normalize_word(word):
    """Drop the leading slash from the escaped tokens "/." and "/?"; return anything else unchanged."""
    return word[1:] if word in ("/.", "/?") else word

def split_into_segments(document_state, max_segment_len, constraints1, constraints2):
    """Cut the document's subtokens into [CLS] ... [SEP] segments.

    Each segment holds at most ``max_segment_len - 3`` subtokens. It ends on
    the last position in that window where ``constraints1`` is true; if the
    window has none, the last position where ``constraints2`` is true is used.

    Appends to ``document_state.segments``, ``segment_subtoken_map`` and
    ``segment_info``.

    Raises:
        Exception: if neither constraint list allows a cut inside a window.
    """
    subtokens = document_state.subtokens
    total = len(subtokens)
    start = 0
    prev_word = 0
    while start < total:
        window_last = min(start + max_segment_len - 3, total - 1)

        # Preferred cut: walk back to the nearest position allowed by constraints1.
        stop = window_last
        while stop >= start and not constraints1[stop]:
            stop -= 1

        if stop < start:
            # Fallback cut: same backward walk, using constraints2.
            stop = window_last
            while stop >= start and not constraints2[stop]:
                stop -= 1
            if stop < start:
                raise Exception("Can't find valid segment")

        cut = stop + 1
        document_state.segments.append(["[CLS]"] + subtokens[start:cut] + ["[SEP]"])

        word_ids = document_state.subtoken_map[start:cut]
        # [CLS] maps to the last word of the previous segment, [SEP] to the last word of this one.
        document_state.segment_subtoken_map.append([prev_word] + word_ids + [word_ids[-1]])

        document_state.segment_info.append([None] + document_state.info[start:cut] + [None])

        prev_word = word_ids[-1]
        start = cut

def get_document(document_lines, tokenizer, segment_len):
    """Build the jsonlines record for one CoNLL document.

    Args:
        document_lines: ``(doc_key, lines)`` pair, where ``lines`` are the raw
            rows of the document; a whitespace-only line marks a sentence end.
        tokenizer: object whose ``tokenize(word)`` returns a list of subtokens.
        segment_len: maximum segment length handed to ``split_into_segments``.

    Returns:
        The dict produced by ``DocumentState.finalize``.

    Raises:
        AssertionError: if a row does not have 12 or 13 columns.
    """
    document_state = DocumentState(document_lines[0])
    word_idx = -1
    for line in document_lines[1]:
        row = line.split()
        if not row:
            # Blank line: flag the last subtoken seen so far as a sentence end.
            # A blank line before the first token has nothing to flag; the
            # unguarded write used to raise IndexError in that case.
            if document_state.sentence_end:
                document_state.sentence_end[-1] = True
            continue

        if len(row) == 12:
            row.append("-") # follow the same structure as ontonotes conll files ()
        assert len(row) == 13, "expected 12 or 13 columns, got {}".format(len(row))

        word_idx += 1
        word = normalize_word(row[3])
        subtokens = tokenizer.tokenize(word)
        document_state.tokens.append(word)
        # NOTE(review): if tokenize() ever returns an empty list, token_end still
        # gains one entry here while subtokens gains none — confirm it cannot.
        document_state.token_end += ([False] * (len(subtokens) - 1)) + [True]
        for sidx, subtoken in enumerate(subtokens):
            document_state.subtokens.append(subtoken)
            # Only the first subtoken of a word carries the row (plus the
            # subtoken count); finalize() relies on that layout.
            info = None if sidx != 0 else (row + [len(subtokens)])
            document_state.info.append(info)
            document_state.sentence_end.append(False)
            document_state.subtoken_map.append(word_idx)

    # Prefer cutting segments at sentence ends; fall back to word ends.
    split_into_segments(document_state, segment_len, document_state.sentence_end, document_state.token_end)
    document = document_state.finalize()
    return document


def minimize_partition(tokenizer, input_path, output_path, seg_len=512):
    """Convert a CoNLL file into a jsonlines file with one record per document.

    Args:
        tokenizer: tokenizer forwarded to ``get_document``.
        input_path: CoNLL file to read (decoded as UTF-8).
        output_path: jsonlines file to write; overwritten if it exists.
        seg_len: maximum segment length forwarded to ``get_document``.

    Raises:
        ValueError: if a non-blank line appears before the first
            "#begin document" line.
    """
    print("Minimizing {}".format(input_path))
    documents = []
    # Decode explicitly as UTF-8 instead of relying on the platform default.
    with open(input_path, "r", encoding="utf-8") as input_file:
        for line in input_file:
            begin_document_match = re.match(conll.BEGIN_DOCUMENT_REGEX, line)
            if begin_document_match:
                doc_key = conll.get_doc_key(begin_document_match.group(1), begin_document_match.group(2))
                documents.append((doc_key, []))
            elif line.startswith("#end document"):
                continue
            elif documents:
                # Everything up to the next "#begin document" belongs to the current document.
                documents[-1][1].append(line)
            elif line.strip():
                # Previously this surfaced as an IndexError on documents[-1].
                raise ValueError(
                    "Line found before the first '#begin document' in {}: {!r}".format(input_path, line))
            # Blank lines before the first document are skipped.

    count = 0
    with open(output_path, "w", encoding="utf-8") as output_file:
        for document_lines in documents:
            document = get_document(document_lines, tokenizer, seg_len)
            output_file.write(json.dumps(document))
            output_file.write("\n")
            count += 1
    print("Wrote {} documents to {}".format(count, output_path))

In [52]:
# Convert the test CoNLL file to jsonlines, using the default seg_len of 512.
minimize_partition(tokenizer, input_path=test_conll_path, output_path=test_jsonlines_path)

Minimizing /root/workspace/sr_coref/src/benchmarking/data/radcoref_test.conll


Wrote 200 documents to /root/workspace/sr_coref/src/benchmarking/data/radcoref_test.jsonlines


In [53]:
# Hand-pasted example: one segment of subtokens wrapped in [CLS]/[SEP], in the
# same layout as the "sentences" field built by DocumentState.finalize().
# NOTE(review): presumably copied from one record of the generated jsonlines — confirm.
sents = [["[CLS]", "As", "compared", "to", "the", "previous", "image", ",", "the", "alignment", "of", "the", "stern", "##al", "wires", "is", "unchanged", ".", "Un", "##chang", "##ed", "position", "of", "the", "right", "internal", "j", "##ug", "##ular", "vein", "cat", "##he", "##ter", ",", "with", "its", "tip", "projecting", "over", "the", "mid", "to", "lower", "SV", "##C", ".", "No", "p", "##ne", "##um", "##oth", "##orax", ".", "Small", "bilateral", "p", "##le", "##ural", "e", "##ff", "##usions", "are", "better", "appreciated", "on", "the", "lateral", "than", "on", "the", "frontal", "image", ".", "There", "are", "limited", "to", "the", "cost", "##op", "##hren", "##ic", "sin", "##uses", ".", "No", "pulmonary", "ed", "##ema", ".", "No", "pneumonia", ".", "[SEP]"]]

# Word index for each entry of the segment above (same layout as "subtoken_map").
subtok_map = [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 11, 12, 13, 14, 15, 16, 16, 16, 17, 18, 19, 20, 21, 22, 22, 22, 23, 24, 24, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 35, 36, 37, 38, 38, 38, 38, 38, 39, 40, 41, 42, 42, 42, 43, 43, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 61, 61, 61, 62, 62, 63, 64, 65, 66, 66, 67, 68, 69, 70, 70]

In [63]:
# Flatten the list of segments into one subtoken list.
sent = [subtok for segment in sents for subtok in segment]

# Word index of the subtoken at position 35.
subtok_map[35:36]

[27]

In [62]:
# NOTE(review): doc_objs is only built by the parsing loop in a later cell, so this
# cell raises NameError when the notebook is run top to bottom on a fresh kernel.
doc_objs[1]

s52428114_impression: [[[19, 24], [27, 27]]]

# Load radcoref test conll

In [59]:
import re

# Read the gold CoNLL file into a list of rows with the newlines removed.
with open(test_conll_path, "r", encoding="utf-8") as f:
    rows = [line.strip("\n") for line in f]

In [None]:
def extract_onto_conll(row):
    """Extract the token and its coreference annotation from one CoNLL row.

    Columns are split on any run of whitespace (spaces or tabs), the same way
    ``get_document`` splits its rows; leading and trailing whitespace is
    ignored.

    Args:
        row: one non-empty CoNLL row; the token is column 3 and the coreference
            annotation is the last column.

    Returns:
        ``(token_str, token_coref_ids)``; ``token_coref_ids`` is None when the
        last column is "-".

    Raises:
        IndexError: if the row has fewer than four columns.
    """
    str_list = row.split()
    token_str = str_list[3]
    token_coref_ids = str_list[-1] if str_list[-1] != "-" else None
    return token_str, token_coref_ids

# NOTE(review): `extract_custom_conll` is not defined in any cell shown in this notebook
# (only `extract_onto_conll` is), so this line raises NameError on a fresh kernel.
# Restore the missing definition or point this at the intended extractor.
row_info_extractor = extract_custom_conll

In [60]:
class ConllDocument:
    """Tokens and coreference clusters of one CoNLL document.

    Tokens are grouped by sentence and numbered with a single running index
    over the whole document; cluster spans are ``[start, end]`` pairs of those
    indices (inclusive).
    """

    def __init__(self, doc_key):
        self.doc_key = doc_key
        self.sent_toks = []  # one list of token strings per sentence
        self.sent_tok_idx = []  # matching document-level token indices
        self.gt_clusters = []  # gt_clusters[cluster_id] -> [[start, end], ...]
        self.pred_clusters = []  # same layout, for predictions

        # Set to True by the caller to make the next token open a new sentence.
        self._new_sent = True
        # Index that the next added token will receive.
        self._tok_pointer = 0

    def add_token(self, token):
        """Append ``token`` to the current sentence and return its document-level index."""
        if self._new_sent:
            self.sent_toks.append([])
            self.sent_tok_idx.append([])
            self._new_sent = False
        tok_idx = self._tok_pointer
        self.sent_toks[-1].append(token)
        self.sent_tok_idx[-1].append(tok_idx)
        self._tok_pointer += 1
        return tok_idx

    def add_gt_cluster(self, token_coref_id, span_start, span_end):
        """Record the start and/or end of a gold mention of cluster ``token_coref_id``.

        With ``span_start`` given, a new span ``[span_start, span_end]`` is
        appended (``span_end`` may be None for a mention that is still open).
        With ``span_start`` None, the most recently opened span of that cluster
        that has no end yet is closed with ``span_end``.

        Raises:
            AssertionError: if a span is closed but the cluster has no open span.
        """
        # Grow the list so that token_coref_id is a valid index.
        while len(self.gt_clusters) < (token_coref_id + 1):
            self.gt_clusters.append([])
        cluster = self.gt_clusters[token_coref_id]

        if span_start is not None:
            cluster.append([span_start, span_end])
            return

        # Closing marker: search newest-first so nested mentions close innermost first.
        open_span = next((span for span in reversed(cluster) if span[1] is None), None)
        assert open_span is not None, (
            "closing cluster {} at token {} without an open span".format(token_coref_id, span_end))
        open_span[1] = span_end

    def add_pred_cluster(self, coref_id, span_start, span_end):
        """Append the predicted span ``[span_start, span_end]`` to cluster ``coref_id``."""
        while len(self.pred_clusters) < (coref_id + 1):
            self.pred_clusters.append([])
        self.pred_clusters[coref_id].append([span_start, span_end])

    def __repr__(self):
        return f"{self.doc_key}: {self.gt_clusters}"

In [61]:
# Group the raw rows into one ConllDocument per "#begin document" ... "#end document" block.
doc_objs = []
current_doc_obj = None
for row in rows:
    # Blank lines outside a document carry no information.
    if current_doc_obj is None and row == "":
        continue

    if row.startswith("#begin"):
        obj = re.match(r"#begin document \((.+)\); part 0", row)
        dockey = obj.group(1)
        current_doc_obj = ConllDocument(dockey)
        continue

    if row == "#end document":
        doc_objs.append(current_doc_obj)
        current_doc_obj = None
        continue

    # Anything else must be inside a document.
    assert current_doc_obj is not None

    # A blank line inside a document starts a new sentence.
    if row == "":
        current_doc_obj._new_sent = True
        continue

    token_str, token_coref_ids = row_info_extractor(row)
    tok_idx = current_doc_obj.add_token(token_str)

    if not token_coref_ids:
        continue

    # One "|"-separated part per cluster: "(N" opens a mention on this token,
    # "N)" closes one, "(N)" does both.
    token_coref_id_list = token_coref_ids.split("|")
    for token_coref_id_str in token_coref_id_list:
        token_coref_id = int(token_coref_id_str.strip("()"))
        span_start = tok_idx if token_coref_id_str.startswith("(") else None
        span_end = tok_idx if token_coref_id_str.endswith(")") else None
        current_doc_obj.add_gt_cluster(token_coref_id, span_start, span_end)

# Predict

In [None]:
GPU=0 python /root/workspace/sr_coref/src/benchmarking/CorefQA/predict.py spanbert_large /root/workspace/sr_coref/src/benchmarking/data/radcoref.txt /root/workspace/sr_coref/src/benchmarking/data/radcoref_out.json