In [1]:
import sys
sys.path.append("../../src")

import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Event
from common_utils.data_loader_utils import load_mimic_cxr_bySection
from common_utils.coref_utils import resolve_mention_and_group_num
from common_utils.file_checker import FileChecker
from common_utils.common_utils import check_and_create_dirs, check_and_remove_dirs
from preprocess_i2b2 import aggregrate_files, I2b2Token, get_file_name_prefix, clean_and_split_line

# Project FileChecker instance kept at module level.
# NOTE(review): no cell visible in this notebook references it — confirm it is still needed.
FILE_CHECKER = FileChecker()
# multiprocessing Event that the driver cell sets after submitting all tasks and
# clears once every result has been collected.
START_EVENT = Event()

In [2]:
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the hydra configuration used by every later cell.
# `config_path` is resolved relative to this notebook, and the override selects
# the i2b2 variant of the `data_preprocessing` config group at the global level.
config = None
with initialize(version_base=None, config_path="../config", job_name="i2b2_for_brat"):
        config = compose(config_name="data_preprocessing", overrides=["data_preprocessing@_global_=i2b2"])

# Aggregate files

In [3]:
# Working directory for the aggregated inputs (taken from the hydra config).
temp_dir = os.path.join(config.temp.dir)
check_and_create_dirs(temp_dir)
# aggregrate_files returns the directory holding the documents and the one holding the chain files.
docs_dir, chains_dir = aggregrate_files(config, temp_dir)

# Drop the Jupyter checkpoint folder before listing, so it is not counted as a document.
# NOTE(review): this hard-codes "<temp_dir>/docs"; presumably that is the same path as docs_dir — confirm.
check_and_remove_dirs(os.path.join(temp_dir, "docs",".ipynb_checkpoints"), True)
# Sanity check: only the number of files is compared here, not their names.
doc_files = os.listdir(docs_dir)
chain_files = os.listdir(chains_dir)
assert len(doc_files) == len(chain_files)


# Resolve files

In [4]:
class AnnMentionClass:
    """One text-bound annotation line (a mention) of a brat ``.ann`` file."""

    def __init__(self) -> None:
        # Annotation identifier, e.g. "T0"; filled in by the caller.
        self.id = ""
        # Annotation type written in front of the character span.
        self.type = "Mention"
        # Character span of the mention; filled in by the caller.
        self.start_index = ""
        self.end_index = ""
        # Surface tokens of the mention, joined with single spaces on output.
        self.token_str_list = []

    def get_ann_str(self) -> str:
        """Serialize the mention as ``<id>\\t<type> <start> <end>\\t<text>\\n``."""
        span_field = " ".join(
            str(part) for part in (self.type, self.start_index, self.end_index)
        )
        text_field = " ".join(self.token_str_list)
        return "\t".join((str(self.id), span_field, text_field)) + "\n"

class AnnCoreferenceClass:
    """One relation line of a brat ``.ann`` file linking two mention ids."""

    def __init__(self) -> None:
        # Relation identifier, e.g. "R0"; filled in by the caller.
        self.id = ""
        # Relation type written in front of the two arguments.
        self.type = "Coreference"
        # Ids of the two linked mentions; filled in by the caller.
        self.anaphora = ""
        self.antecedent = ""

    def get_ann_str(self) -> str:
        """Serialize as ``<id>\\t<type> Anaphora:<id> Antecedent:<id>\\t\\n``."""
        arguments = f"Anaphora:{self.anaphora} Antecedent:{self.antecedent}"
        relation_field = " ".join((str(self.type), arguments))
        # The trailing empty field reproduces the tab that precedes the newline.
        return "\t".join((str(self.id), relation_field, "")) + "\n"

In [5]:
# Process each file
def _tokens_for_span(sentence_list: list, start: str, end: str) -> list:
    """Return the tokens covered by an inclusive ``sent:tok`` .. ``sent:tok`` span.

    The indexing treats sentence ids as 1-based and token ids as 0-based.

    Args:
        sentence_list: Tokens of the document, one inner list per sentence.
        start: Position of the first token, formatted ``"<sentence>:<token>"``.
        end: Position of the last token (inclusive), same format.

    Returns:
        A new list with the tokens from ``start`` to ``end``.

    Raises:
        IndexError: If a single-token position or a sentence id is out of range.
    """
    sent_start, tok_start = (int(part) for part in start.split(":"))
    sent_end, tok_end = (int(part) for part in end.split(":"))

    if (sent_start, tok_start) == (sent_end, tok_end):
        # Single-token mention: index directly so a bad position fails loudly.
        return [sentence_list[sent_start - 1][tok_start]]

    if sent_start == sent_end:
        return sentence_list[sent_start - 1][tok_start:tok_end + 1]

    # The mention crosses sentence boundaries: tail of the first sentence,
    # every sentence in between, then the head of the last sentence.
    tokens = list(sentence_list[sent_start - 1][tok_start:])
    for sent_id in range(sent_start + 1, sent_end + 1):
        sentence = sentence_list[sent_id - 1]
        tokens.extend(sentence[:tok_end + 1] if sent_id == sent_end else sentence)
    return tokens


def batch_processing(doc_file_path, chain_file_path):
    """Resolve a single i2b2 document, including a .txt file and a .chains file.

    The document is tokenized line by line, every mention of every chain is
    mapped to its tokens, and two files are written into
    ``<config.output_base_dir>/brat_visualization``: ``<doc_id>.txt`` (tokens
    joined by single spaces, sentences joined by newlines) and ``<doc_id>.ann``
    (one ``T`` line per mention, one ``R`` line per pair of consecutive
    mentions of the same chain).

    Args:
        doc_file_path: Path of the document ``.txt`` file.
        chain_file_path: Path of the chain file that belongs to the document.

    Returns:
        None. The results are the two files written to disk.
    """
    doc_id = get_file_name_prefix(doc_file_path, ".txt")

    # Resolve doc file: one token list per input line.
    sentence_list: list[list[I2b2Token]] = []
    with open(doc_file_path, "r", encoding="UTF-8-sig") as doc:
        tokenId_docwise = 0
        for sentence_id, doc_line in enumerate(doc):
            token_list: list[I2b2Token] = []
            for tokenId_sentencewise, token_str in enumerate(clean_and_split_line(doc_line, debug_doc=doc_id, debug_sent=sentence_id)):
                token_list.append(I2b2Token(doc_id, sentence_id, tokenId_sentencewise, tokenId_docwise, token_str))
                tokenId_docwise += 1
            sentence_list.append(token_list)

    # Resolve chain file (coref cluster): one line per chain, mentions separated by "||".
    coref_group_list: list[list[list[I2b2Token]]] = []
    with open(chain_file_path, "r", encoding="UTF-8-sig") as chain:
        for cluster in chain:
            coref_group: list[list[I2b2Token]] = []
            for coref in cluster.split("||")[:-1]:  # Drop the last one, which is the type of the coref
                # The last two space-separated fields are the start and end positions.
                start, end = coref.split(" ")[-2:]
                coref_group.append(_tokens_for_span(sentence_list, start, end))
            coref_group_list.append(coref_group)

    output_dir = os.path.join(config.output_base_dir, "brat_visualization")
    check_and_create_dirs(output_dir)

    # Assign character offsets that match the text written below: one space
    # between tokens and one newline between sentences.
    offset = 0
    docStr_list: list[str] = []
    for sentence in sentence_list:
        sentenceStr_list: list[str] = []
        for position, i2b2_token in enumerate(sentence):
            if position > 0:
                offset += 1  # the space that joins this token to the previous one
            i2b2_token.offset = offset
            offset += len(i2b2_token.tokenStr)
            sentenceStr_list.append(i2b2_token.tokenStr)
        docStr_list.append(" ".join(sentenceStr_list))
        # The newline that joins sentences is counted for empty sentences too;
        # otherwise every offset after a blank line would be one character short.
        offset += 1

    with open(os.path.join(output_dir, f"{doc_id}.txt"), "w", encoding="UTF-8") as f:
        f.write("\n".join(docStr_list))

    mention_id, pair_id = 0, 0
    with open(os.path.join(output_dir, f"{doc_id}.ann"), "w", encoding="UTF-8") as f:
        for _coref_list in coref_group_list:

            ann_mention_list: list[AnnMentionClass] = []
            for _mention_list in _coref_list:
                ann_mention_class = AnnMentionClass()
                ann_mention_class.id = f"T{mention_id}"
                ann_mention_class.start_index = _mention_list[0].offset
                # End offset is exclusive: start of the last token plus its length.
                ann_mention_class.end_index = _mention_list[-1].offset + len(_mention_list[-1].tokenStr)
                ann_mention_class.token_str_list.extend(token.tokenStr for token in _mention_list)

                f.write(ann_mention_class.get_ann_str())
                mention_id += 1

                ann_mention_list.append(ann_mention_class)

            # Link every mention to the one listed right before it in the chain.
            # NOTE(review): the earlier mention is written as Anaphora and the later
            # one as Antecedent; if chains list mentions in document order these
            # roles look swapped — confirm before relying on the direction.
            for previous_mention, current_mention in zip(ann_mention_list, ann_mention_list[1:]):
                ann_coref_class = AnnCoreferenceClass()
                ann_coref_class.id = f"R{pair_id}"
                ann_coref_class.anaphora = previous_mention.id
                ann_coref_class.antecedent = current_mention.id

                f.write(ann_coref_class.get_ann_str())
                pair_id += 1


# Futures of every submitted document, in submission order.
all_task = []
# A single worker process handles the documents one after another.
with ProcessPoolExecutor(max_workers=1) as executor:
    # Submit one task per document (first progress bar).
    for _file_name in tqdm(doc_files):
        # Input files: the chain file shares the document's file name plus the configured suffix.
        doc_file_path = os.path.join(docs_dir, _file_name)
        chain_file_path = os.path.join(chains_dir, _file_name + config.input.chain_suffix)
        all_task.append(executor.submit(batch_processing, doc_file_path, chain_file_path))

    # Notify tasks to start.
    # NOTE(review): batch_processing as shown in this notebook never waits on START_EVENT — confirm whether the event is still needed.
    START_EVENT.set()

    # Collect results as tasks finish (second progress bar);
    # future.result() re-raises any exception raised inside the worker.
    if all_task:
        for future in tqdm(as_completed(all_task), total=len(all_task)):
            future.result()

    START_EVENT.clear()

100%|██████████| 424/424 [00:00<00:00, 25748.32it/s]
100%|██████████| 424/424 [00:01<00:00, 354.83it/s]
