# Getting ready for manual annotation on MIMIC-CXR dataset

There are 227835 reports in the dataset. Each report has multiple sections, such as findings and impression.

1. Balanced sampling: reports can differ a lot in the anatomies and observations they describe, and we don't want the sample to be dominated by the same anatomy and observation.
   1. To do so, we can count how many times each anatomy appears in each report.
2. 

## Preparing

In [138]:
import sys
sys.path.append("../../src")
sys.path.append("../../../../git_clone_repos/fast-coref/src")

import os
import ast
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from IPython.display import display, HTML
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Event
from common_utils.data_loader_utils import load_mimic_cxr_bySection
from common_utils.coref_utils import resolve_mention_and_group_num
from common_utils.file_checker import FileChecker
from common_utils.common_utils import check_and_create_dirs

# Shared objects used by the cells below.
FILE_CHECKER = FileChecker()  # used via .ignore()/.filter() to skip unwanted files
# Worker processes wait() on this event; it is set() once every task has been submitted.
START_EVENT = Event()

mpl.style.use("default")

SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

# (rc group, option name, font size) applied to matplotlib's global settings.
for _rc_group, _rc_option, _rc_size in (
    ("font", "size", SMALL_SIZE),          # default text size
    ("axes", "titlesize", SMALL_SIZE),     # axes title
    ("axes", "labelsize", MEDIUM_SIZE),    # x and y labels
    ("xtick", "labelsize", SMALL_SIZE),    # x tick labels
    ("ytick", "labelsize", SMALL_SIZE),    # y tick labels
    ("legend", "fontsize", SMALL_SIZE),    # legend
    ("figure", "titlesize", BIGGER_SIZE),  # figure title
):
    plt.rc(_rc_group, **{_rc_option: _rc_size})
del _rc_group, _rc_option, _rc_size

In [2]:
from hydra import compose, initialize
from omegaconf import OmegaConf

# Load every MIMIC-CXR section through the nlp_ensemble hydra config.
config = None
with initialize(version_base=None, config_path="../config", job_name="nlp_ensemble"):
    config = compose(config_name="data_preprocessing", overrides=["+nlp_ensemble@_global_=mimic_cxr"])
    section_name_cfg = config.name_style.mimic_cxr.section_name
    output_section_cfg = config.output.section
    input_path = config.input.path
    data_size, pid_list, sid_list, section_list = load_mimic_cxr_bySection(
        input_path, output_section_cfg, section_name_cfg
    )

# Sort all per-report lists by study id so the section texts stay aligned with their sid.
_sorted_rows = sorted(
    zip(
        sid_list,
        section_list[0][1],  # findings
        section_list[1][1],  # impression
        section_list[2][1],  # provisional_findings_impression
        section_list[3][1],  # findings_and_impression
    )
)
s_list, f_list, i_list, pfi_list, fai_list = zip(*_sorted_rows)
del _sorted_rows

sid_list = s_list
section_list = [
    ("findings", f_list),
    ("impression", i_list),
    ("provisional_findings_impression", pfi_list),
    ("findings_and_impression", fai_list),
]

## Analyse

### Word frequency

In [None]:
from collections import Counter

def batch_processing(sid, input_path):
    """Count how often each token occurs in one spaCy-output CSV.

    Blocks on START_EVENT so that no worker starts before all tasks are queued.

    Args:
        sid: Report id; returned unchanged so the caller can match results.
        input_path: Path of the per-report CSV containing a "[sp]token" column.

    Returns:
        Tuple ``(sid, Counter)`` mapping each token string to its count.
    """
    START_EVENT.wait()
    tokens = pd.read_csv(input_path, index_col=0, na_filter=False).loc[:, "[sp]token"]
    return sid, Counter(tokens.to_list())


# For every section directory of the spaCy output, sum token frequencies over all reports
# (in parallel) and print the 100 most common tokens.
with os.scandir("../../output/mimic_cxr/nlp_ensemble/spacy") as section_entries:
    for section_entry in section_entries:
        if not section_entry.is_dir():
            continue
        total_word_counter = Counter()

        tasks = []
        with ProcessPoolExecutor(max_workers=14) as executor:
            with os.scandir(section_entry.path) as report_entries:
                for report_entry in tqdm(report_entries):
                    if FILE_CHECKER.ignore(os.path.abspath(report_entry.path)):
                        continue
                    # removesuffix drops exactly ".csv"; rstrip(".csv") would also eat any
                    # trailing '.', 'c', 's' or 'v' characters belonging to the id itself.
                    sid = report_entry.name.removesuffix(".csv")
                    tasks.append(executor.submit(batch_processing, sid, report_entry.path))

            # Release the workers only after every task has been queued.
            START_EVENT.set()

            # Receive results from multiprocessing.
            for future in tqdm(as_completed(tasks), total=len(tasks)):
                _, word_counter = future.result()
                total_word_counter.update(word_counter)

            # All results consumed: reset so the next section's workers wait again.
            START_EVENT.clear()

        print("Section:", section_entry.name)
        print(total_word_counter.most_common(100))
        print("-" * 80)

### Scatter plot of "number of tokens" vs. "number of coreference groups" per document

In [None]:
from collections import Counter

# Output directories of the three coreference systems; each contains "<section>/<sid>.csv".
scoref_dir, dcoref_dir, fcoref_dir = (
    "../../output/mimic_cxr/nlp_ensemble/corenlp/scoref",
    "../../output/mimic_cxr/nlp_ensemble/corenlp/dcoref",
    "../../output/mimic_cxr/nlp_ensemble/fast_coref_joint",
)

def batch_processing(section_name, sid, spacy_input_path):
    """Count the tokens and the coreference groups of one report.

    Blocks on START_EVENT until all tasks are queued, then reads the spaCy CSV and the
    matching "<section>/<sid>.csv" of each coreference system.

    Returns:
        ``(sid, token_num, scoref_group_num, dcoref_group_num, fcoref_group_num)``.
    """
    START_EVENT.wait()
    read_kwargs = {"index_col": 0, "na_filter": False}
    df_spacy = pd.read_csv(spacy_input_path, **read_kwargs)

    # (output dir, CoNLL group column) for: ML-based, rule-based, neural-based systems.
    coref_sources = (
        (scoref_dir, "[co][ml]coref_group_conll"),
        (dcoref_dir, "[co][rb]coref_group_conll"),
        (fcoref_dir, "[fj]coref_group_conll"),
    )
    # Read every file first, then resolve (same order of work as before).
    coref_frames = [
        (pd.read_csv(os.path.join(base_dir, section_name, sid + ".csv"), **read_kwargs), column)
        for base_dir, column in coref_sources
    ]

    token_num = len(df_spacy.loc[:, "[sp]token"].to_list())

    scoref_group_num, dcoref_group_num, fcoref_group_num = (
        resolve_mention_and_group_num(frame, column)[1] for frame, column in coref_frames
    )
    return sid, token_num, scoref_group_num, dcoref_group_num, fcoref_group_num

# Per section: sid -> count summary (used for the bucketed statistics below).
section_doc_numData_dict: dict[str, dict[str, dict[str, float]]] = {}
# Per section: list of count summaries (used for the scatter plots below).
section_scatter_data_list = {}
with os.scandir("../../output/mimic_cxr/nlp_ensemble/spacy") as section_entries:
    for section_entry in section_entries:
        if not section_entry.is_dir():
            continue
        print("Processing section:", section_entry.name)
        doc_numData_dict: dict[str, dict[str, float]] = {}
        section_doc_numData_dict[section_entry.name] = doc_numData_dict

        tasks = []
        scatter_data_list: list[dict] = []
        with ProcessPoolExecutor(max_workers=14) as executor:
            with os.scandir(section_entry.path) as report_entries:
                for report_entry in tqdm(report_entries):
                    if FILE_CHECKER.ignore(os.path.abspath(report_entry.path)):
                        continue
                    # removesuffix drops exactly ".csv" (rstrip would strip a character set).
                    sid = report_entry.name.removesuffix(".csv")
                    tasks.append(
                        executor.submit(batch_processing, section_entry.name, sid, report_entry.path)
                    )

            # Release the workers only after every task has been queued.
            START_EVENT.set()

            # Receive results from multiprocessing.
            for future in tqdm(as_completed(tasks), total=len(tasks)):
                sid, token_num, scoref_group_num, dcoref_group_num, fcoref_group_num = future.result()
                numData = {
                    "tokNum": token_num,
                    "sNum": scoref_group_num,
                    "dNum": dcoref_group_num,
                    "fNum": fcoref_group_num,
                    # Unweighted mean of the three systems' group counts.
                    "avgNum": (scoref_group_num + dcoref_group_num + fcoref_group_num) / 3,
                }
                doc_numData_dict[sid] = numData  # for later statistics
                scatter_data_list.append(numData)  # for the scatter plot

            START_EVENT.clear()

        section_scatter_data_list[section_entry.name] = scatter_data_list

In [None]:
# One figure per section: number of tokens (x) vs. number of coreference groups (y),
# one panel per system plus a panel for their unweighted mean.
_SCATTER_PANELS = (
    ("sNum", "ML-based"),
    ("dNum", "Rule-based"),
    ("fNum", "Neural-based"),
    ("avgNum", "Unweighted mean"),
)
for section_name, scatter_data_list in section_scatter_data_list.items():
    points = sorted(scatter_data_list, key=lambda x: x["tokNum"])  # Sort by token num
    tok_nums = [point["tokNum"] for point in points]  # shared x values of all panels

    fig, axes = plt.subplots(len(_SCATTER_PANELS), 1, sharex=True, sharey=True, figsize=(9, 7))
    fig.suptitle(f"Section: {section_name}")

    for ax, (key, title) in zip(axes, _SCATTER_PANELS):
        ax.scatter(tok_nums, [point[key] for point in points], s=1, alpha=0.5)
        ax.set_title(title)

    axes[1].set_ylabel("Number of coreference", fontdict={"size": 14})
    axes[3].set_xlabel("Number of tokens", fontdict={"size": 14})
    # Lay out once, after all titles/labels exist; calling it before drawing has no useful effect.
    fig.tight_layout()
            

### Bar charts of the (averaged) number of coreference groups, bucketed by token count, for each section

In [None]:
# Bucket labels; kept identical to the original keys because later cells index them by name.
_TOK_RANGE_LABELS = ("0-49toks", "50-99toks", "100-149toks", ">150toks")
_COREF_RANGE_LABELS = ("0coref", "0-1coref", "1coref", "1-2coref", "2coref", ">2coref")


def _tok_range_key(tok_num):
    """Return the token-range bucket label for ``tok_num`` (None for negative counts).

    Upper bounds are exclusive at 50/100/150, so 99 and 149 land in their buckets
    (the previous ``< 99`` / ``< 149`` tests silently dropped those documents).
    The ">150toks" label actually holds every document with ``tok_num >= 150``.
    """
    if tok_num < 0:
        return None
    if tok_num < 50:
        return "0-49toks"
    if tok_num < 100:
        return "50-99toks"
    if tok_num < 150:
        return "100-149toks"
    return ">150toks"


def _coref_range_key(avg_num):
    """Return the coreference-range bucket label for an averaged group count."""
    if avg_num == 0:
        return "0coref"
    if 0 < avg_num < 1:
        return "0-1coref"
    if avg_num == 1:
        return "1coref"
    if 1 < avg_num < 2:
        return "1-2coref"
    if avg_num == 2:
        return "2coref"
    return ">2coref"  # everything else, as in the original else-branch


# section -> token range -> coreference range -> list of doc ids
section_tokRange_corefRange_docList_dict: dict[str, dict] = {}
for section_name, doc_numData_dict in section_doc_numData_dict.items():
    tokRange_corefRange_docList_dict = {
        tok_label: {coref_label: [] for coref_label in _COREF_RANGE_LABELS}
        for tok_label in _TOK_RANGE_LABELS
    }
    section_tokRange_corefRange_docList_dict[section_name] = tokRange_corefRange_docList_dict

    for doc_id, numData_dict in doc_numData_dict.items():
        tok_label = _tok_range_key(numData_dict["tokNum"])
        if tok_label is None:
            continue
        coref_label = _coref_range_key(numData_dict["avgNum"])
        tokRange_corefRange_docList_dict[tok_label][coref_label].append(doc_id)

    # One bar chart per token range: number of documents in each coreference range.
    fig, axs = plt.subplots(len(_TOK_RANGE_LABELS), 1, figsize=(9, 7))
    fig.suptitle(f"Section: {section_name}")

    for ax, (tokRangeLabel, corefRangeDict) in zip(axs, tokRange_corefRange_docList_dict.items()):
        bar_container = ax.bar(
            list(corefRangeDict.keys()),
            [len(_sid_list) for _sid_list in corefRangeDict.values()],
        )
        ax.set_title(tokRangeLabel)
        ax.bar_label(bar_container, padding=3)

    axs[1].set_ylabel("Number of notes (reports)")
    axs[3].set_xlabel("Coreference Range")

    fig.tight_layout()


### Confusion-matrix-style heatmap (token range × coreference range)

#### Including documents with 0 coreference groups

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Per section: heatmap of document counts per (token range, coreference range) cell, as a
# percentage of all documents of the section, with an extra "all" row/column of marginal sums.
for section_name, tokRange_corefRange_docList_dict in section_tokRange_corefRange_docList_dict.items():
    data = np.array([
        [len(docList) for docList in corefRange_docList_dict.values()]
        for corefRange_docList_dict in tokRange_corefRange_docList_dict.values()
    ])
    total = data.sum()
    if total == 0:
        # Avoid a 0/0 division (NaN heatmap) for an empty section.
        print(f"Section: {section_name} has no documents; skipped.")
        continue
    percentages = data / total * 100

    # Append column sums as a last row, then row sums (including that row) as a last column.
    with_col_totals = np.vstack([percentages, percentages.sum(axis=0)])
    new_percentages = np.hstack([with_col_totals, with_col_totals.sum(axis=1, keepdims=True)])

    x_labels = ["0", "0-1", "1", "1-2", "2", ">2", "all"]
    y_labels = ["0-49", "50-99", "100-149", ">=150", "all"]  # last bucket holds >= 150 tokens

    fig, ax = plt.subplots()
    sns.heatmap(new_percentages, xticklabels=x_labels, yticklabels=y_labels, annot=True, fmt='.2f', ax=ax)
    ax.set_title(f"Section: {section_name}, total: {total}")
    ax.set_xlabel("Number of coreferences")
    ax.set_ylabel("Number of tokens")
    # Lay out after drawing; calling it on an empty figure (as before) had no effect.
    fig.tight_layout()

print("Each value shown in a cell is percentage")

#### Excluding documents with 0 coreference groups

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Same heatmap as above, but documents in the "0coref" column are left out, so the
# percentages are relative to documents that have at least some coreference.
for section_name, tokRange_corefRange_docList_dict in section_tokRange_corefRange_docList_dict.items():
    data = np.array([
        [
            len(docList)
            for corefRange, docList in corefRange_docList_dict.items()
            if corefRange != "0coref"
        ]
        for corefRange_docList_dict in tokRange_corefRange_docList_dict.values()
    ])
    total = data.sum()
    if total == 0:
        # Avoid a 0/0 division (NaN heatmap) when no document has any coreference.
        print(f"Section: {section_name} has no documents with coreferences; skipped.")
        continue
    percentages = data / total * 100

    # add overall statistic: column sums as a last row, then row sums as a last column
    with_col_totals = np.vstack([percentages, percentages.sum(axis=0)])
    new_percentages = np.hstack([with_col_totals, with_col_totals.sum(axis=1, keepdims=True)])

    x_labels = ["0-1", "1", "1-2", "2", ">2", "all"]
    y_labels = ["0-49", "50-99", "100-149", ">=150", "all"]  # last bucket holds >= 150 tokens

    fig, ax = plt.subplots()
    sns.heatmap(new_percentages, xticklabels=x_labels, yticklabels=y_labels, annot=True, fmt='.2f', ax=ax)
    ax.set_title(f"Section: {section_name}, total: {total}")
    ax.set_xlabel("Number of coreferences")
    ax.set_ylabel("Number of tokens")
    fig.tight_layout()

print("Each value shown in a cell is percentage")

In [None]:
# Scratch check: raw document counts with marginal totals appended.
# NOTE(review): `data` is whatever the preceding heatmap cell left behind, i.e. the counts of the
# last section it looped over (0-coref column excluded) — re-run that cell first.
print(data)

column_totals = [column.sum() for column in data.T]
new_data = np.append(data, [column_totals], 0)

row_totals = [[row.sum()] for row in new_data]
new_data = np.append(new_data, row_totals, 1)

print(new_data)


## Sampling without existing annotations

### Example

In [None]:
# Inspect one document: its coreference group count per system and its findings text.
section_name = "findings"
doc_id = section_tokRange_corefRange_docList_dict[section_name]["100-149toks"]["1coref"][0]
print("section_name, doc_id:", section_name, doc_id)

# (output dir, CoNLL group column) per system; all files are read before any is resolved.
coref_frames = [
    (pd.read_csv(os.path.join(coref_dir, section_name, f"{doc_id}.csv"), index_col=0, na_filter=False), conll_col)
    for coref_dir, conll_col in (
        (scoref_dir, "[co][ml]coref_group_conll"),
        (dcoref_dir, "[co][rb]coref_group_conll"),
        (fcoref_dir, "[fj]coref_group_conll"),
    )
]
scoref_group_num, dcoref_group_num, fcoref_group_num = (
    resolve_mention_and_group_num(frame, conll_col)[1] for frame, conll_col in coref_frames
)

print("scoref_group_num, dcoref_group_num, fcoref_group_num:", scoref_group_num, dcoref_group_num, fcoref_group_num)
# section_list[0] holds the findings texts, aligned with sid_list.
print(section_list[0][1][sid_list.index(doc_id)])

### Sample

In [None]:
import random
random.seed(42)

# Documents to draw per cell: rows follow the token ranges (0-49, 50-99, 100-149, >=150),
# columns follow the coreference ranges (0, 0-1, 1, 1-2, 2, >2).
sampling_num_dict = {
    "findings": [
        [2, 3, 2, 1, 1, 1],
        [7, 5, 5, 4, 1, 1],
        [3, 2, 2, 2, 1, 1],
        [1, 1, 1, 1, 1, 1],
    ],
    "impression": [
        [4, 3, 2, 1, 1, 1],
        [6, 5, 5, 4, 1, 1],
        [2, 2, 2, 2, 1, 1],
        [1, 1, 1, 1, 1, 1],
    ],
}
output_section_doc_dict = {}
for section_name, tokRange_corefRange_docList_dict in section_tokRange_corefRange_docList_dict.items():
    if section_name not in sampling_num_dict:
        continue
    section_counts = sampling_num_dict[section_name]
    sampled_docs = output_section_doc_dict[section_name] = []
    for x_idx, (tokRange, corefRange_docList_dict) in enumerate(tokRange_corefRange_docList_dict.items()):
        for y_idx, (corefRange, docList) in enumerate(corefRange_docList_dict.items()):
            required_num = section_counts[x_idx][y_idx]
            # Draw WITHOUT replacement from a sorted copy (sorting makes the seeded draw
            # independent of collection order). The previous randint-based draw could pick the
            # same document twice and raised ValueError on an empty cell; a cell with fewer
            # documents than requested now contributes all of them.
            # NOTE: for the same seed this selects a different set than the old version did.
            sampled_doc = random.sample(sorted(docList), min(required_num, len(docList)))
            sampled_docs.extend(sampled_doc)
            print(section_name, tokRange, corefRange, sampled_doc)

### Generate txt

In [None]:
from tqdm import tqdm
from common_utils.common_utils import check_and_remove_dirs

# Write one plain-text file per sampled document (one sentence per line) for brat annotation.
brat_txt_root = "../../output/mimic_cxr/statistic/for_brat_annotation"
check_and_remove_dirs(brat_txt_root, True)
for section_name, _doc_list in output_section_doc_dict.items():
    section_out_dir = os.path.join(brat_txt_root, section_name)
    check_and_create_dirs(section_out_dir)
    spacy_section_dir = os.path.join("../../output/mimic_cxr/nlp_ensemble/spacy", section_name)
    for doc_id in tqdm(_doc_list):
        df_spacy = pd.read_csv(os.path.join(spacy_section_dir, f"{doc_id}.csv"), index_col=0, na_filter=False)
        # One space-joined string per sentence group, in group order.
        joined_per_sentence = df_spacy.groupby(['[sp]sentence_group'])['[sp]token'].apply(' '.join)
        sentences = [str(joined).strip() for joined in joined_per_sentence]
        with open(os.path.join(section_out_dir, f"{doc_id}.txt"), "w", encoding="UTF-8") as f:
            f.write("\n".join(sentences))

## Sampling with existing annotations taken into account

Get the data distribution

In [3]:
from collections import defaultdict, Counter

from common_utils.coref_utils import get_file_name_prefix


# CoNLL group column and output root of the fast-coref (onto_i2b2) run used for the distribution.
coref_group_conll_colName = "[fj]coref_group_conll"
input_base_dir = "../../output/mimic_cxr/nlp_ensemble/fast_coref_onto_i2b2"


def batch_processing2(file_name, input_file_path):
    """Return ``(doc_id, coref_group_num)`` for one fast-coref CSV.

    The group count comes from resolve_mention_and_group_num with ``omit_singleton=True``
    (presumably single-mention groups are not counted — confirm in coref_utils).
    ``file_name`` is not used; it is kept so existing submit() calls stay valid.
    """
    START_EVENT.wait()
    doc_id = get_file_name_prefix(input_file_path, ".csv")
    frame = pd.read_csv(input_file_path, index_col=0, na_filter=False)
    group_count = resolve_mention_and_group_num(frame, coref_group_conll_colName, omit_singleton=True)[1]
    return doc_id, group_count

# Per section: Counter of coref group counts, and coref group count -> list of doc ids.
section_doc_coref_counter = defaultdict(Counter)
all_section_corefGroupNum_docId_dict = {}
for section_name in ["findings", "impression"]:
    input_dir = os.path.join(input_base_dir, section_name)
    with ProcessPoolExecutor(max_workers=15) as executor:
        all_task = [
            executor.submit(batch_processing2, file_name, os.path.join(input_dir, file_name))
            for file_name in tqdm(FILE_CHECKER.filter(os.listdir(input_dir)))
        ]

        # Notify tasks to start
        START_EVENT.set()

        corefGroupNum_docId_dict = defaultdict(list)
        if all_task:
            for future in tqdm(as_completed(all_task), total=len(all_task)):
                doc_id, coref_group_num = future.result()
                section_doc_coref_counter[section_name][coref_group_num] += 1
                corefGroupNum_docId_dict[coref_group_num].append(doc_id)

        all_section_corefGroupNum_docId_dict[section_name] = corefGroupNum_docId_dict
    # Leaving the `with` block already waited for every worker (executor shutdown).
    START_EVENT.clear()

100%|██████████| 156011/156011 [00:03<00:00, 43962.40it/s]
100%|██████████| 156011/156011 [00:39<00:00, 3917.51it/s]
100%|██████████| 189465/189465 [00:04<00:00, 40970.63it/s]
100%|██████████| 189465/189465 [00:47<00:00, 4004.77it/s]


In [142]:
# A list of gt doc
gt_source_list = [
    "../../output/mimic_cxr/ground_truth",
    "../../output/mimic_cxr/coref_voting/majority_voting_sampling1"
]
gt_section_docId_dict:dict[str,list] = {}
for section_name in ["findings","impression"]:
    for source_dir in gt_source_list:
        gt_all_dir = os.path.join(source_dir, section_name)
        gt_section_docId_dict.setdefault(section_name,[]).extend([i.rstrip(".csv") for i in FILE_CHECKER.filter(os.listdir(gt_all_dir))])
    
# The coref distribution (majority_voting version) of the annotated gt files.
gt_section_corefGroupNum_docId_dict = {}
for section_name in ["findings","impression"]:
    gt_section_corefGroupNum_docId_dict[section_name] = {}
    for coref_num, doc_id_list in all_section_corefGroupNum_docId_dict[section_name].items():
        gt_section_corefGroupNum_docId_dict[section_name][coref_num] = []
        for doc_id in gt_section_docId_dict[section_name]:
            if doc_id in doc_id_list:
                gt_section_corefGroupNum_docId_dict[section_name][coref_num].append(doc_id)

In [143]:
# Print (coref group count, number of docs) pairs per section for: all docs, docs already
# selected for annotation, and the remaining pool.
print("The data distribution of all docs:")
for section_name, corefGroupNum_docId_dict in all_section_corefGroupNum_docId_dict.items():
    pairs = [(coref_num, len(doc_ids)) for coref_num, doc_ids in corefGroupNum_docId_dict.items()]
    print(section_name, pairs)

print("\nThe data distribution of test docs (docs that had been selected for annotation):")
for section_name, corefGroupNum_docId_dict in gt_section_corefGroupNum_docId_dict.items():
    pairs = [(coref_num, len(doc_ids)) for coref_num, doc_ids in corefGroupNum_docId_dict.items()]
    print(section_name, pairs)

print("\nThe data distribution of all docs excluding test docs")
for section_name, corefGroupNum_docId_dict in all_section_corefGroupNum_docId_dict.items():
    pairs = []
    for coref_num, doc_ids in corefGroupNum_docId_dict.items():
        annotated = gt_section_corefGroupNum_docId_dict[section_name][coref_num]
        pairs.append((coref_num, len(doc_ids) - len(annotated)))
    print(section_name, pairs)

The data distribution of all docs:
findings [(1, 15956), (0, 136639), (2, 2624), (3, 597), (6, 12), (4, 145), (5, 37), (7, 1)]
impression [(0, 173932), (1, 12634), (2, 2197), (3, 561), (4, 110), (5, 22), (6, 6), (7, 2), (8, 1)]

The data distribution of test docs (docs that had been selected for annotation):
findings [(1, 43), (0, 25), (2, 22), (3, 7), (6, 0), (4, 2), (5, 1), (7, 0)]
impression [(0, 26), (1, 41), (2, 21), (3, 8), (4, 2), (5, 2), (6, 0), (7, 0), (8, 0)]

The data distribution of all docs excluding test docs
findings [(1, 15913), (0, 136614), (2, 2602), (3, 590), (6, 12), (4, 143), (5, 36), (7, 1)]
impression [(0, 173906), (1, 12593), (2, 2176), (3, 553), (4, 108), (5, 20), (6, 6), (7, 2), (8, 1)]


Design the sampling distribution for the next round

In [144]:
# Number of new documents to sample per coreference group count; same plan for both sections
# (each section gets its own dict object).
section_corefGroupNum_sampleDocNum_dict = {
    section: {1: 20, 2: 20, 3: 7, 4: 2, 5: 1}
    for section in ("findings", "impression")
}

Get the target doc_ids for a new annotation round (majority voting is also performed on them)

In [146]:
from hydra import compose, initialize
from omegaconf import OmegaConf
from nlp_ensemble.nlp_menbers import play_fastcoref

# Load the coreference-resolution (coref voting) config for MIMIC-CXR.
config = None
with initialize(version_base=None, config_path="../config", job_name="create_ann"):
    config = compose(
        config_name="coreference_resolution",
        overrides=["+coreference_resolution/coref_voting@_global_=mimic_cxr"],
    )

In [147]:
from common_utils.coref_utils import shuffle_list
import coref_voting
from coref_voting import DocClass, MentionClass, compute_voting_result, get_output_df

# Output root for the majority-voting results of this sampling round.
mv_output_base_dir = "../../output/mimic_cxr/coref_voting/majority_voting_sampling2"


def batch_processing3(config, spacy_file_path, section_name, file_name):
    """Run coreference voting on one document and write the result CSV.

    Args:
        config: Composed coref-voting config, passed through to coref_voting helpers.
        spacy_file_path: spaCy output CSV used as the token alignment base.
        section_name: Report section; also the output sub-directory.
        file_name: "<doc_id>.csv"; used as the output file name.

    Returns:
        The status string ``"<file_name> done."``.
    """
    START_EVENT.wait()

    df_spacy = pd.read_csv(spacy_file_path, index_col=0, na_filter=False)

    # Some i2b2 raw files are UTF-8 with a BOM that was never removed; strip it here
    # (both the decoded character and its three-character mis-decoded form).
    def _strip_bom(value):
        if not isinstance(value, str):
            return value
        return value.replace("\ufeff", "").replace("\xef\xbb\xbf", "")

    df_spacy.iloc[0] = df_spacy.iloc[0].apply(_strip_bom)

    docObj: DocClass = coref_voting.resolve_voting_info(config, df_spacy, section_name, file_name)
    valid_mention_group: list[set[MentionClass]] = compute_voting_result(config, docObj)
    df_out = get_output_df(config, df_spacy, valid_mention_group, docObj)

    section_out_dir = os.path.join(mv_output_base_dir, section_name)
    check_and_create_dirs(section_out_dir)
    df_out.to_csv(os.path.join(section_out_dir, file_name))

    return f"{file_name} done."

# For each section and coref group count, draw the planned number of not-yet-annotated
# documents and run majority voting on them in parallel.
for section_name, corefGroupNum_sampleDocNum_dict in section_corefGroupNum_sampleDocNum_dict.items():
    spacy_out_dir = os.path.join(config.input.source.baseline_model.dir, section_name)
    with ProcessPoolExecutor(max_workers=config.thread.workers) as executor:
        all_task = []
        for groupNum, sampleDocNum in corefGroupNum_sampleDocNum_dict.items():
            # Get the actual doc ids, remove the ones already used in the test set, then shuffle.
            docId_all_list = all_section_corefGroupNum_docId_dict[section_name][groupNum]
            docId_exclude_set = set(gt_section_corefGroupNum_docId_dict[section_name].get(groupNum, []))
            # Sort first: the candidate lists were filled in `as_completed` order, which varies
            # between runs, so shuffling them directly made the seeded sample non-reproducible.
            docId_list_excluded = sorted(x for x in docId_all_list if x not in docId_exclude_set)
            docId_list_shuffle = shuffle_list(docId_list_excluded, 42)
            if len(docId_list_shuffle) < sampleDocNum:
                print(f"Warning: {section_name} group {groupNum}: only {len(docId_list_shuffle)} "
                      f"of {sampleDocNum} requested docs available.")

            for doc_id in docId_list_shuffle[:sampleDocNum]:
                file_name = doc_id + ".csv"
                spacy_file_path = os.path.join(spacy_out_dir, file_name)
                all_task.append(executor.submit(batch_processing3, config, spacy_file_path, section_name, file_name))

        # Notify tasks to start
        START_EVENT.set()

        if all_task:
            for future in tqdm(as_completed(all_task), total=len(all_task)):
                future.result()  # re-raises any worker exception
    # Leaving the `with` block already waited for every worker (executor shutdown).
    START_EVENT.clear()

100%|██████████| 50/50 [00:54<00:00,  1.09s/it]
100%|██████████| 50/50 [00:48<00:00,  1.04it/s]


In [148]:
class AnnMentionClass:
    """One brat text-bound ("T") annotation of type ``Mention``.

    Attributes (set by the caller after construction):
        id: brat annotation id, e.g. "T0".
        type: brat entity type, "Mention".
        start_index / end_index: character offsets into the document text; the covered
            text is ``text[start_index:end_index]`` (end exclusive). "" until set.
        token_str_list: optional explicit tokens for the .ann text field.
        token_str: covered text, filled in by ``set_end_index``.
    """

    def __init__(self) -> None:
        self.id = ""
        self.type = "Mention"
        self.start_index = ""
        self.end_index = ""
        self.token_str_list = []
        self.token_str = ""

    def get_ann_str(self) -> str:
        """Return the brat standoff line ``<id>\\t<type> <start> <end>\\t<text>\\n``.

        The text field uses ``token_str_list`` when it was filled explicitly and otherwise
        ``token_str`` (the span captured by ``set_end_index``). token_str_list is never filled
        by the code in this notebook, so relying on it alone left the text field empty.
        """
        span_text = " ".join(self.token_str_list) if self.token_str_list else self.token_str
        # A newline inside the text field would break the one-annotation-per-line .ann format.
        span_text = span_text.replace("\n", " ")
        return f"{self.id}\t{self.type} {self.start_index} {self.end_index}\t{span_text}\n"

    def set_end_index(self, value, text):
        """Set the (exclusive) end offset and capture the covered span of ``text``."""
        self.end_index = value
        self.token_str = text[self.start_index:self.end_index]

    def __repr__(self) -> str:
        return self.get_ann_str()

    def __str__(self) -> str:
        return self.get_ann_str()

class AnnCoreferenceClass:
    """One brat relation ("R") annotation of type ``Coreference`` linking two mention ids."""

    def __init__(self) -> None:
        self.id = ""  # brat relation id, e.g. "R0"
        self.type = "Coreference"
        self.anaphora = ""  # mention id written as the "Anaphora" argument
        self.antecedent = ""  # mention id written as the "Antecedent" argument

    def get_ann_str(self) -> str:
        """Render ``<id>\\t<type> Anaphora:<id> Antecedent:<id>\\t\\n`` (brat standoff line)."""
        arguments = " ".join((f"Anaphora:{self.anaphora}", f"Antecedent:{self.antecedent}"))
        return "".join((self.id, "\t", self.type, " ", arguments, "\t\n"))

    def __repr__(self) -> str:
        return self.get_ann_str()

    __str__ = __repr__


def get_AnnMentionClass_notClosed(ann_ment_list: "list[AnnMentionClass]") -> "AnnMentionClass | None":
    """Return the most recently opened mention that has no end offset yet, or None.

    CoNLL coreference brackets nest, so a closing ``k)`` label belongs to the innermost
    (last-opened) still-open mention of cluster ``k``. Assumes ``ann_ment_list`` is in the
    order the mentions were opened. Searching from the end only changes the result when
    mentions of the same cluster are nested; otherwise at most one mention is open.
    """
    for annMent in reversed(ann_ment_list):
        if annMent.end_index == "":
            return annMent
    return None


In [149]:
# Convert every majority-voting result into a brat document pair:
#   <doc_id>.txt - one sentence per line, tokens joined by single spaces;
#   <doc_id>.ann - one "Mention" text-bound annotation per mention, plus "Coreference"
#                  relations chaining consecutive mentions of the same group.
brat_output_dir = "../../output/mimic_cxr/brat_annotation_original/round3"

for section_name in ["findings", "impression"]:
    input_dir = os.path.join(mv_output_base_dir, section_name)
    output_dir = os.path.join(brat_output_dir, section_name)

    # removesuffix drops exactly ".csv" (rstrip would strip a character set).
    for doc_id in [i.removesuffix(".csv") for i in FILE_CHECKER.filter(os.listdir(input_dir))]:
        df_spacy = pd.read_csv(os.path.join(input_dir, f"{doc_id}.csv"), index_col=0, na_filter=False)
        # "#@#" is a temporary separator so the tokens can be split back out exactly.
        df_sentence = df_spacy.groupby(['[sp]sentence_group'])['[sp]token'].apply('#@#'.join).reset_index()
        df_coref = df_spacy.groupby(['[sp]sentence_group'])["[mv]coref_group_conll"].apply(list).reset_index()

        sentences = [str(_series.get("[sp]token")).replace("#@#", " ").strip() for _, _series in df_sentence.iterrows()]
        text = "\n".join(sentences)
        # .txt files
        check_and_create_dirs(output_dir)
        with open(os.path.join(output_dir, f"{doc_id}.txt"), "w", encoding="UTF-8") as f:
            f.write(text)

        # .ann files
        with open(os.path.join(output_dir, f"{doc_id}.ann"), "w", encoding="UTF-8") as f:
            mention_id = 0
            groupNum_mentions_dict: dict[int, list[AnnMentionClass]] = defaultdict(list)
            offset = 0
            for _idx, _series in df_sentence.iterrows():
                # Drop surrounding whitespace and the separator left by a trailing whitespace token.
                # NOTE(review): strip("#@#") removes any leading/trailing '#'/'@' characters, and a
                # sentence starting with a whitespace token loses it here but not in the label list
                # below, shifting token/label alignment by one — confirm spaCy never produces either.
                token_list = _series.get('[sp]token').strip().strip("#@#").split('#@#')
                # One label per original token; a trailing whitespace token's label is never read.
                conll_labelStr_list = df_coref.loc[_idx,].get("[mv]coref_group_conll")

                for tok_id, tok in enumerate(token_list):
                    conll_labelListStr = conll_labelStr_list[tok_id]
                    if conll_labelListStr not in [-1, "-1", np.nan]:
                        for conll_label in ast.literal_eval(conll_labelListStr):
                            if "(" in conll_label:
                                ann_mention_class = AnnMentionClass()
                                ann_mention_class.id = f"T{mention_id}"
                                ann_mention_class.start_index = offset
                                mention_id += 1
                                if ")" in conll_label:
                                    # "(k)": a single-token mention, opened and closed here.
                                    ann_mention_class.set_end_index(offset + len(tok), text)
                                    coref_id = int(conll_label.replace("(", "").replace(")", ""))
                                else:
                                    coref_id = int(conll_label.replace("(", ""))
                                groupNum_mentions_dict[coref_id].append(ann_mention_class)
                            elif ")" in conll_label:
                                # "k)": closes an open mention of group k.
                                coref_id = int(conll_label.replace(")", ""))
                                ann_mention_class = get_AnnMentionClass_notClosed(groupNum_mentions_dict[coref_id])
                                if ann_mention_class is None:
                                    raise ValueError(
                                        f"{section_name}/{doc_id}: closing label {conll_label!r} "
                                        f"has no open mention of group {coref_id}"
                                    )
                                ann_mention_class.set_end_index(offset + len(tok), text)

                    offset += len(tok) + 1  # +1 for the single space joining tokens
                # Start of the next sentence: all sentence lengths so far plus one "\n" after each.
                offset = sum(len(sent) for sent in sentences[0:_idx + 1]) + _idx + 1

            pair_id = 0
            for ann_mention_list in groupNum_mentions_dict.values():
                for ann_mention_class in ann_mention_list:
                    f.write(ann_mention_class.get_ann_str())

                # Chain consecutive mentions of the group (in the order they were opened).
                # NOTE(review): the earlier mention is written as "Anaphora" and the later one as
                # "Antecedent"; usually the antecedent precedes the anaphor — confirm against the
                # brat annotation.conf / earlier rounds before changing the direction.
                for prev_mention, cur_mention in zip(ann_mention_list, ann_mention_list[1:]):
                    ann_coref_class = AnnCoreferenceClass()
                    ann_coref_class.id = f"R{pair_id}"
                    ann_coref_class.anaphora = prev_mention.id
                    ann_coref_class.antecedent = cur_mention.id

                    f.write(ann_coref_class.get_ann_str())
                    pair_id += 1