# Resolve the MIMIC-CXR manual annotation

Output the CoNLL-formatted files.

In [59]:
# Notebook parameters.
# INCLUDE_SINGLETON: when False, only mentions that take part in at least one
# brat relation (i.e. belong to a coreference cluster) get labels in the output;
# when True, mentions without any relation are labelled as their own cluster too.
INCLUDE_SINGLETON = False
# brat project folder; read per section from its "findings" / "impression"
# subfolders as <doc_id>.txt / <doc_id>.ann pairs.
brat_source_dir = "../../output/brat_annotation/round4_500_1234r3"
# One CSV per document is written to <output_base_dir>/<section>/<doc_id>.csv.
# NOTE: this folder is passed to check_and_remove_dirs before anything is written.
output_base_dir = "../../output/mimic_cxr/manual_training_set/round4_500_1234r3"

## Preparing

In [60]:
import sys
sys.path.append("../../src")

import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from IPython.display import display, HTML
# display(HTML(df.to_html()))

from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Event
from common_utils.data_loader_utils import load_mimic_cxr_bySection
from common_utils.coref_utils import resolve_mention_and_group_num
from common_utils.file_checker import FileChecker
from common_utils.common_utils import check_and_create_dirs, check_and_remove_dirs

# Shared helpers used by the cells below.
FILE_CHECKER = FileChecker()
START_EVENT = Event()

# Reset matplotlib to its stock style, then enlarge the font sizes.
mpl.style.use("default")

SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

plt.rcParams.update({
    "font.size": SMALL_SIZE,           # default text size
    "axes.titlesize": SMALL_SIZE,      # axes title
    "axes.labelsize": MEDIUM_SIZE,     # x / y axis labels
    "xtick.labelsize": SMALL_SIZE,     # x tick labels
    "ytick.labelsize": SMALL_SIZE,     # y tick labels
    "legend.fontsize": SMALL_SIZE,     # legend text
    "figure.titlesize": BIGGER_SIZE,   # figure suptitle
})

In [61]:
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the "coreference_resolution" hydra config, selecting the MIMIC-CXR
# variant of the nlp_ensemble group at the global package.
config = None
with initialize(version_base=None, config_path="../config", job_name="nlp_ensemble"):
    config = compose(
        config_name="coreference_resolution",
        overrides=["+nlp_ensemble@_global_=mimic_cxr"],
    )

In [62]:
# Reset the output folder before writing: remove it, then create it again empty.
# NOTE(review): the meaning of the positional True flag is defined in
# common_utils.common_utils.check_and_remove_dirs (presumably "remove recursively"
# or "ignore if missing") — confirm there.
check_and_remove_dirs(output_base_dir, True)
check_and_create_dirs(output_base_dir)

## Read from BRAT output

In [63]:
def find_sub_list(sublist, source_list) -> "tuple[int, int] | None":
    """Locate the first contiguous occurrence of ``sublist`` inside ``source_list``.

    Args:
        sublist: The non-empty sequence of elements to look for.
        source_list: The sequence to search in (same sequence type as ``sublist``,
            since matching compares a slice of it with ``sublist`` using ``==``).

    Returns:
        ``(start, end)`` with ``end`` inclusive, i.e.
        ``source_list[start:end + 1] == sublist``; ``None`` when there is no match.

    Raises:
        ValueError: If ``sublist`` is empty (there is no span to report).
    """
    if not sublist:
        raise ValueError("sublist must contain at least one element")
    length = len(sublist)
    first = sublist[0]
    for start, item in enumerate(source_list):
        # Cheap first-element check before comparing the whole slice.
        if item == first and source_list[start:start + length] == sublist:
            return start, start + length - 1
    return None

# greeting = ["\n",""]
# print(find_sub_list([""], greeting))

In [64]:
class BratMention:
    """One text-bound annotation (a mention) read from a brat ``.ann`` file.

    Identity is the brat id alone: two mentions with the same id are equal and
    hash alike. A mention also compares equal to the bare id value itself, so
    ``some_list.index("T3")`` finds the mention whose id is ``"T3"``.
    """

    def __init__(self, id, start, end, mention_str) -> None:
        self.id = id
        # Character offsets stored exactly as given (no conversion here);
        # the end offset is exclusive.
        self.tok_start = start
        self.tok_end = end
        # Coreference cluster ids, filled in after construction. The brat scheme
        # used here links a mention to a single cluster, so one element is expected.
        self.group_id_list = []
        self.mention_str = mention_str

    def __eq__(self, __o: object) -> bool:
        # Compare against another mention's id, or against a raw id value.
        other_id = __o.id if isinstance(__o, BratMention) else __o
        return self.id == other_id

    def __hash__(self) -> int:
        # Must agree with __eq__, which only looks at the id.
        return hash(self.id)

    def __str__(self) -> str:
        return "{}({},{})".format(self.id, self.tok_start, self.tok_end)

    def __repr__(self) -> str:
        return self.__str__()

class BratCorefGroup:
    """Builds coreference clusters incrementally from pairwise brat relations.

    ``coref_group`` holds one set of mentions per cluster. When a new pair
    touches several existing clusters, those clusters are merged into a single
    set that is appended at the end of the list.
    """

    def __init__(self) -> None:
        self.coref_group:list[set[BratMention]] = []

    def add(self, ment_a:BratMention, ment_b:BratMention):
        """Record that ``ment_a`` and ``ment_b`` corefer, merging clusters as needed."""
        pair = (ment_a, ment_b)
        touched = []
        for position, cluster in enumerate(self.coref_group):
            if ment_a not in cluster and ment_b not in cluster:
                continue
            # Both mentions join every cluster that already contains either of them.
            cluster.update(pair)
            touched.append(position)

        if not touched:
            # Neither mention is clustered yet: start a new cluster.
            self.coref_group.append(set(pair))
            return
        if len(touched) == 1:
            # The single matching cluster was already extended in place.
            return

        # The pair bridges several clusters: fold them into one and move it last.
        merged = set().union(*(self.coref_group[position] for position in touched))
        for position in reversed(touched):
            del self.coref_group[position]
        self.coref_group.append(merged)

    def __str__(self) -> str:
        # Clusters separated by "|", mentions within a cluster by ",".
        return "|".join(",".join(str(mention) for mention in cluster) for cluster in self.coref_group)

In [65]:
import ast

# Folder holding the per-document spaCy token CSVs the brat .txt files were built from.
spacy_base_dir = "../../output/mimic_cxr/nlp_ensemble/spacy"
# Tokens are joined with this marker instead of whitespace because some spaCy
# tokens are themselves whitespace, which would break a plain split().
_TOKEN_SEP = "#@#"
_GROUP_COL = "[gt]coref_group"
_CONLL_COL = "[gt]coref_group_conll"


def _align_brat_to_spacy(df_spacy: pd.DataFrame, txt_file_str: str) -> pd.DataFrame:
    """Attach brat character offsets to the spaCy tokens of one document.

    Each spaCy sentence is re-joined, its surrounding whitespace tokens removed
    (the brat txt was written without them), and the remaining tokens are
    searched for left to right in ``txt_file_str`` to obtain their offsets.

    Returns the spaCy rows outer-merged with the brat tokens, restricted to the
    columns used downstream, plus the two label columns initialised to the
    sentinel ``-1`` ("no label yet").
    """
    grouped = df_spacy.groupby(["[sp]sentence_group"])
    df_sentence = grouped["[sp]token"].apply(_TOKEN_SEP.join).reset_index()
    sentences = [str(tok) for tok in df_sentence["[sp]token"]]
    # Positional row indices of each sentence's tokens; assumed to follow the same
    # group order as df_sentence (both come from the same groupby).
    sentence_tok_ids = list(grouped.indices.values())

    brat_rows = []
    brat_offset = 0
    prev_brat_tok = ""
    last_offset = None
    for sentence_str, id_list in zip(sentences, sentence_tok_ids):
        tok_list_spacy = sentence_str.split(_TOKEN_SEP)
        # NOTE(review): strip(_TOKEN_SEP) strips the characters "#" and "@", not the
        # literal separator; kept as-is to match how the brat txt was produced.
        tok_list_brat = sentence_str.strip().strip(_TOKEN_SEP).split(_TOKEN_SEP)

        if tok_list_brat == [""]:
            start, end = 0, 0
        else:
            span = find_sub_list(tok_list_brat, tok_list_spacy)
            if span is None:
                raise ValueError(f"brat tokens not found in spaCy sentence: {tok_list_brat!r}")
            start, end = span

        for brat_tok, spacy_idx in zip(tok_list_brat, id_list[start:end + 1]):
            # Search from the end of the previous token. prev_brat_tok is kept across
            # sentences so a sentence's first token cannot match inside the previous
            # sentence's last token.
            search_from = brat_offset + len(prev_brat_tok)
            brat_offset = search_from + txt_file_str[search_from:].index(brat_tok)
            if last_offset is not None and brat_offset == last_offset:
                # An empty brat_tok would reuse the previous offset; keep offsets distinct.
                brat_offset += 1
            brat_rows.append((brat_tok, spacy_idx, brat_offset))
            last_offset = brat_offset
            prev_brat_tok = brat_tok

    df_brat = pd.DataFrame(brat_rows, columns=["brat_tok", "spacy_index", "brat_offset"])
    df_aligned = (
        df_spacy.merge(df_brat, how="outer", left_index=True, right_on="spacy_index")
        .reset_index()
        .drop(columns=["index"])
    )
    df_aligned = df_aligned.loc[:, ["[sp]token", "[sp]token_offset", "[sp]sentence_group",
                                    "brat_tok", "spacy_index", "brat_offset"]].copy()
    # object dtype: the -1 sentinel is later replaced by stringified label lists,
    # which an int64 column would not accept without a dtype-change warning/error.
    for column in (_GROUP_COL, _CONLL_COL):
        df_aligned[column] = pd.Series(-1, index=df_aligned.index, dtype=object)
    return df_aligned


def _parse_brat_annotations(ann_lines: list) -> tuple:
    """Parse brat .ann lines into ``(mention_list, BratCorefGroup)``.

    ``T`` lines become mentions (using the first and last offsets of the span
    field, i.e. the outer extent of a discontinuous span); every ``R`` line links
    its two argument mentions as coreferent. Blank and other lines are ignored.
    """
    mention_list = []
    mention_by_id = {}
    brat_coref_obj = BratCorefGroup()
    for line in (raw.strip() for raw in ann_lines):
        if not line:
            continue
        fields = line.split("\t")
        if line[0] == "T":
            span_fields = fields[1].split(" ")
            mention = BratMention(fields[0], span_fields[1], span_fields[-1], fields[2])
            mention_list.append(mention)
            mention_by_id[mention.id] = mention
        elif line[0] == "R":
            arg_fields = fields[1].split(" ")
            mention_a = mention_by_id[arg_fields[1].split(":")[-1]]
            mention_b = mention_by_id[arg_fields[2].split(":")[-1]]
            brat_coref_obj.add(mention_a, mention_b)
    return mention_list, brat_coref_obj


def _assign_group_ids(mention_list: list, brat_coref_obj) -> list:
    """Give every mention a cluster id and return the coreferent ones in .ann order.

    Cluster ``i`` of ``brat_coref_obj.coref_group`` gets id ``i``; every mention
    outside all clusters (a singleton) gets its own fresh id after those.
    """
    for coref_id, coref_group in enumerate(brat_coref_obj.coref_group):
        for mention in coref_group:
            mention.group_id_list.append(coref_id)

    coreferent = {mention for group in brat_coref_obj.coref_group for mention in group}
    next_coref_id = len(brat_coref_obj.coref_group)
    for mention in mention_list:
        if mention not in coreferent:
            mention.group_id_list.append(next_coref_id)
            next_coref_id += 1
    # Follow .ann order (not set iteration order) so the output is reproducible.
    return [mention for mention in mention_list if mention in coreferent]


def _append_labels(df: pd.DataFrame, row_idx, column: str, labels: list, dedupe: bool) -> None:
    """Add ``labels`` to the stringified list stored in ``df.loc[row_idx, column]``.

    A cell still holding the ``-1`` sentinel gets ``str(labels)``; otherwise the
    stored list is parsed back, extended, optionally de-duplicated, and stored
    again as a string.
    """
    current = df.loc[row_idx, column]
    if not isinstance(current, str):
        df.loc[row_idx, column] = str(list(labels))
        return
    merged = ast.literal_eval(current) + list(labels)
    df.loc[row_idx, column] = str(list(set(merged)) if dedupe else merged)


def _label_mention(df_aligned: pd.DataFrame, mention) -> str:
    """Write one mention's cluster ids into ``df_aligned`` and return its spaCy text.

    Every token whose brat offset lies in ``[tok_start, tok_end)`` gets the ids in
    ``[gt]coref_group``. ``[gt]coref_group_conll`` gets ``"(id)"`` for a one-token
    mention, or ``"(id"`` on the first and ``"id)"`` on the last token otherwise.
    Returns ``""`` when no token falls in the span.
    """
    in_span = ((df_aligned["brat_offset"] >= int(mention.tok_start))
               & (df_aligned["brat_offset"] < int(mention.tok_end)))
    rows = df_aligned.index[in_span.to_numpy(dtype=bool)]
    if len(rows) == 0:
        return ""

    group_ids = mention.group_id_list
    for row_idx in rows:
        _append_labels(df_aligned, row_idx, _GROUP_COL, group_ids, dedupe=True)

    # Each cell is checked on its own, so a mention nested inside an earlier,
    # longer mention (whose middle tokens have no conll label yet) is handled.
    first_idx, last_idx = rows[0], rows[-1]
    if len(rows) == 1:
        _append_labels(df_aligned, first_idx, _CONLL_COL, [f"({cid})" for cid in group_ids], dedupe=False)
    else:
        _append_labels(df_aligned, first_idx, _CONLL_COL, [f"({cid}" for cid in group_ids], dedupe=False)
        _append_labels(df_aligned, last_idx, _CONLL_COL, [f"{cid})" for cid in group_ids], dedupe=False)

    return " ".join(df_aligned.loc[first_idx:last_idx, "[sp]token"].tolist())


for section_name in ["findings", "impression"]:
    brat_dir = os.path.join(brat_source_dir, section_name)
    spacy_dir = os.path.join(spacy_base_dir, section_name)
    output_dir = os.path.join(output_base_dir, section_name)
    check_and_create_dirs(output_dir)

    doc_ids = [f[:-len(".txt")] for f in FILE_CHECKER.filter(os.listdir(brat_dir)) if f.endswith(".txt")]
    for doc_id in doc_ids:
        # brat outputs
        with open(os.path.join(brat_dir, doc_id + ".txt"), "r", encoding="UTF-8") as f:
            txt_file_str = f.read()
        with open(os.path.join(brat_dir, doc_id + ".ann"), "r", encoding="UTF-8") as f:
            ann_file_list = f.readlines()

        # The spaCy tokens the brat txt file was generated from.
        df_spacy = pd.read_csv(os.path.join(spacy_dir, f"{doc_id}.csv"), index_col=0, na_filter=False)
        df_aligned = _align_brat_to_spacy(df_spacy, txt_file_str)

        mention_list, brat_coref_obj = _parse_brat_annotations(ann_file_list)
        all_coreferent_mention = _assign_group_ids(mention_list, brat_coref_obj)
        source_mention_list = mention_list if INCLUDE_SINGLETON else all_coreferent_mention

        for mention in source_mention_list:
            mention_str = _label_mention(df_aligned, mention)
            # Report (but do not stop on) brat/spaCy text mismatches.
            if mention_str != mention.mention_str:
                print(f"AssertionError warning: doc_id: {doc_id}, brat label: [{mention.mention_str}], spacy token: [{mention_str}]")

        # Write CSV files
        df_out = df_aligned.loc[:, ["[sp]token", "[sp]sentence_group", _GROUP_COL, _CONLL_COL]]
        df_out.to_csv(os.path.join(output_dir, f"{doc_id}.csv"))

    print("Output: ", output_dir)

Output:  ../../output/mimic_cxr/manual_training_set/round4_500_1234r3/findings
Output:  ../../output/mimic_cxr/manual_training_set/round4_500_1234r3/impression
