# Resolve the annotated data

This notebook converts the BRAT annotations into JSON files for named-entity recognition (NER) and relation extraction (RE).

## Prepare

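1. Set `brat_data_dir` to the root directory of the BRAT-annotated data, `to_be_annotated_dir` to the directory containing the `label_in_use` file lists, and `output_root_dir` to the directory for the generated JSON files. **Note:** `output_root_dir` is deleted and recreated on every run.
2. Run all cells.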

In [11]:
import os

# Input locations. Each can be overridden with an environment variable so the
# notebook can run outside the original author's machine; the defaults are the
# original absolute paths.
# BRAT-annotated .txt/.ann files, laid out as <dataset>/<split>/.
brat_data_dir = os.environ.get(
    "CXR_GRAPH_BRAT_DATA_DIR",
    "/Users/liao/myProjects/repo/remote_brat/data/structured_reporting/ours/liao",
)
# Lists of report files to convert, laid out as <dataset>/label_in_use/<split>/.
to_be_annotated_dir = os.environ.get(
    "CXR_GRAPH_TO_BE_ANNOTATED_DIR",
    "/Users/liao/myProjects/VSCode_workspace/cxr_graph/graph_annotation_process/outputs/to_be_annotated",
)
# Output location. NOTE: this directory is deleted and recreated before conversion.
output_root_dir = "./outputs/cxr_graph/json4ner_re"

In [12]:
import os
import json
import re
import copy
from collections import defaultdict

In [13]:
import shutil


def _reset_output_dir(path):
    """Start from an empty directory: delete `path` recursively if it exists, then create it."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


_reset_output_dir(output_root_dir)

## Convert BRAT annotations to JSON

In [14]:
class AnnEntityClass:
    """One BRAT text-bound annotation line: ``T<n><TAB><label> <start> <end><TAB><text>``.

    Character offsets: ``start_index`` is inclusive, ``end_index`` exclusive.
    The token/sentence indices and the attribute-derived fields
    (``abnormality``, ``action``, ``evolution``) are placeholders that callers
    fill in after construction.

    Equality: equal to another ``AnnEntityClass`` with the same ``brat_id``;
    compared with any other object, equal when that object equals ``brat_id``
    (so ``"T0" in entity_list`` and ``entity_list.index("T0")`` work).

    Raises:
        ValueError: if the line does not match the entity pattern, or the
            label is not one of the known entity labels.
    """

    # Known entity labels and the coarse type each maps to.
    _LABEL_TO_TYPE = {
        "Observation-Present": "OBS",
        "Observation-Absent": "OBS",
        "Observation-Uncertain": "OBS",
        "Anatomy": "ANAT",
        "Location-Attribute": "LOCATT",
    }
    _LINE_PATTERN = r"(T\d+)\t(.+) (\d+) (\d+)\t(.+)"
    _CHAIN_RELATIONS = ("modify", "part_of", "located_at", "suggestive_of")

    def __init__(self, stripped_str) -> None:
        # Fields parsed from the BRAT line by resolve().
        self.brat_id = ""  # e.g. T0
        self.label = ""
        self.start_index = -1  # char offset of the first char (inclusive)
        self.end_index = -1  # char offset past the last char (exclusive)
        self.token_str = ""
        self.att_objs = []

        # Position in the tokenized document; both token indices are inclusive.
        self.id = ""  # e.g. E0
        self.start_token_idx = -1
        self.end_token_idx = -1
        self.sent_idx = -1

        self.type = ""  # ANAT, OBS or LOCATT
        # Independent {"in": [], "out": []} lists for every relation kind.
        self.chain_info = {relation: {"in": [], "out": []} for relation in self._CHAIN_RELATIONS}
        self.resolve(stripped_str)

        # Attribute-derived fields default to "NA" until an attribute sets them.
        self.abnormality = "NA"
        self.action = "NA"
        self.evolution = "NA"

    def resolve(self, stripped_str):
        """Parse a stripped ``.ann`` entity line into this object's fields."""
        found = re.match(self._LINE_PATTERN, stripped_str)
        if not found:
            raise ValueError(f"Cannot resolve: {stripped_str}")
        self.brat_id, self.label, raw_start, raw_end, self.token_str = found.groups()
        self.start_index, self.end_index = int(raw_start), int(raw_end)
        entity_type = self._LABEL_TO_TYPE.get(self.label)
        if entity_type is None:
            raise ValueError(f"Cannot identify: {self.label}")
        self.type = entity_type

    def get_ann_str(self) -> str:
        """Serialize back to a BRAT ``.ann`` line (with trailing newline)."""
        return f"{self.brat_id}\t{self.label} {self.start_index} {self.end_index}\t{self.token_str}\n"

    def __str__(self) -> str:
        return self.get_ann_str()

    __repr__ = __str__

    def __eq__(self, other):
        if not isinstance(other, AnnEntityClass):
            return other == self.brat_id
        return self.brat_id == other.brat_id

    def __hash__(self):
        return hash(self.brat_id)


class AnnRelationClass:
    """One BRAT relation line: ``R<n><TAB><label> Arg1:T<i> Arg2:T<j>``.

    ``arg1`` is the source entity id and ``arg2`` the target entity id.

    Equality: equal to another ``AnnRelationClass`` with the same ``brat_id``;
    compared with any other object, equal when that object equals ``brat_id``.

    Raises:
        ValueError: if the line does not match the relation pattern.
    """

    _LINE_PATTERN = r"(R\d+)\t(.+) Arg1:(T\d+) Arg2:(T\d+)"

    def __init__(self, stripped_str) -> None:
        self.brat_id = ""  # e.g. R0
        self.label = ""
        self.arg1 = ""  # source entity, e.g. T0
        self.arg2 = ""  # target entity, e.g. T1
        self.resolve(stripped_str)

        self.id = ""  # e.g. R0

    def resolve(self, stripped_str):
        """Parse a stripped ``.ann`` relation line into this object's fields."""
        found = re.match(self._LINE_PATTERN, stripped_str)
        if found is None:
            raise ValueError(f"Cannot resolve: {stripped_str}")
        self.brat_id, self.label, self.arg1, self.arg2 = found.groups()

    def get_ann_str(self) -> str:
        """Serialize back to a BRAT ``.ann`` line (trailing tab and newline)."""
        return f"{self.brat_id}\t{self.label} Arg1:{self.arg1} Arg2:{self.arg2}\t\n"

    def __str__(self) -> str:
        return self.get_ann_str()

    __repr__ = __str__

    def __eq__(self, other):
        if not isinstance(other, AnnRelationClass):
            return other == self.brat_id
        return self.brat_id == other.brat_id

    def __hash__(self):
        return hash(self.brat_id)


class AnnAttributeClass:
    """One BRAT attribute line: ``A<n><TAB><label> T<i>`` or ``A<n><TAB><label> T<i> <value>``.

    ``value`` is ``None`` when the line carries no value (binary attributes).

    Equality: equal to another ``AnnAttributeClass`` with the same ``brat_id``;
    compared with any other object, equal when that object equals either
    ``brat_id`` or ``target_entity_id`` (so ``"T3" in attribute_list`` is true
    if some attribute targets entity T3). Hashing uses ``brat_id`` only.

    Raises:
        ValueError: if the line does not match the attribute pattern.
    """

    def __init__(self, stripped_str) -> None:
        self.brat_id = ""  # e.g. A0
        self.label = ""
        self.value = ""  # None after resolve() when the line has no value
        self.target_entity_id = ""  # e.g. T0

        self.resolve(stripped_str)

    def get_ann_str(self) -> str:
        """Serialize back to a BRAT ``.ann`` line (no trailing newline)."""
        if self.value:
            return f"{self.brat_id}\t{self.label} {self.target_entity_id} {self.value}"
        else:
            return f"{self.brat_id}\t{self.label} {self.target_entity_id}"

    def get_json_str(self) -> "str | None":
        """Map the attribute to its JSON tag.

        Returns:
            ``"is_abnormal"`` / ``"is_normal"`` for the two normality labels,
            ``"<prefix>:<value>"`` for the valued labels, and ``None`` for any
            other label.
        """
        if self.label == "isAbnormal_OBS":
            return "is_abnormal"
        if self.label == "isNormal_OBS":
            return "is_normal"
        # NOTE(review): "Uncertian" (sic) is kept as-is; presumably it matches the
        # label spelling in the BRAT annotation config -- verify before renaming.
        if self.label == "Uncertian_Tendency":
            return f"uncertainty:{self.value}"
        if self.label == "isRelative_Modifier":
            return f"is_relative_modifier:{self.value}"
        if self.label == "show_RelativeChange":
            return f"has_relative_change:{self.value}"
        return None

    def __repr__(self) -> str:
        return self.get_ann_str()

    def __str__(self) -> str:
        return self.get_ann_str()

    def resolve(self, stripped_str):
        """Parse a stripped ``.ann`` attribute line; the trailing value is optional."""
        patten = r"(A\d+)\t(.+) (T\d+) ?(.+)?"
        match_obj = re.match(patten, stripped_str)
        if match_obj:
            self.brat_id, self.label, self.target_entity_id, self.value = match_obj.groups()
        else:
            raise ValueError(f"Cannot resolve: {stripped_str}")

    def __eq__(self, other):
        if isinstance(other, AnnAttributeClass):
            return self.brat_id == other.brat_id
        else:
            return other == self.brat_id or other == self.target_entity_id

    def __hash__(self):
        return hash(self.brat_id)

In [15]:
# Per "<dataset>_<split>_<output file>" key: |subject sentence idx - object sentence idx|
# for every relation written; summarized in a later cell.
cross_sent_relations = defaultdict(list)


def _remove_prefix(text, prefix):
    """Return ``text`` without a leading ``prefix`` (unchanged if absent).

    Unlike ``str.lstrip``, which strips any of the given *characters*, this
    removes the exact prefix string only.
    """
    return text[len(prefix):] if text.startswith(prefix) else text


def _remove_suffix(text, suffix):
    """Return ``text`` without a trailing ``suffix`` (unchanged if absent)."""
    return text[: len(text) - len(suffix)] if text.endswith(suffix) else text


def _make_doc_key(dataset_name, file_name):
    """Derive the JSON ``doc_key`` from a ``<dataset_name>_...`` file name.

    MIMIC-CXR: the dataset prefix is removed and the remaining underscores
    become ``/`` (the ``.txt`` extension is kept). CheXpert: the dataset prefix
    and the ``.txt`` extension are removed.

    Raises:
        ValueError: for any other dataset name.
    """
    stem = _remove_prefix(file_name, f"{dataset_name}_")
    if dataset_name == "MIMIC-CXR":
        return stem.replace("_", "/")
    if dataset_name == "CheXpert":
        return _remove_suffix(stem, ".txt")
    raise ValueError(f"Unknown dataset name: {dataset_name}")


def _index_by_brat_id(ent_obj_list):
    """Map brat_id -> entity, keeping the first entity for a repeated id."""
    ent_by_id = {}
    for ent in ent_obj_list:
        ent_by_id.setdefault(ent.brat_id, ent)
    return ent_by_id


def _parse_ann_lines(ann_lines):
    """Parse the lines of a BRAT ``.ann`` file.

    ``T`` / ``R`` / ``A`` lines become ``AnnEntityClass`` / ``AnnRelationClass``
    / ``AnnAttributeClass`` objects; blank lines are skipped; any other line
    raises. Each attribute is appended to its target entity's ``att_objs`` and
    sets that entity's ``abnormality`` / ``action`` / ``evolution``. Attributes
    are attached after every line is read, so an attribute may precede its entity.

    Returns:
        (entities, relations), each in file order.

    Raises:
        ValueError: on an unrecognized line or an attribute whose target entity
            does not exist.
    """
    ent_obj_list = []
    rel_obj_list = []
    att_obj_list = []
    for ann_line in ann_lines:
        stripped_ann_line = ann_line.strip()
        if not stripped_ann_line:
            continue
        if stripped_ann_line.startswith("T"):
            ent_obj_list.append(AnnEntityClass(stripped_ann_line))
        elif stripped_ann_line.startswith("R"):
            rel_obj_list.append(AnnRelationClass(stripped_ann_line))
        elif stripped_ann_line.startswith("A"):
            att_obj_list.append(AnnAttributeClass(stripped_ann_line))
        else:
            raise ValueError(f"Uncatched value from .ann file: {stripped_ann_line}")

    ent_by_id = _index_by_brat_id(ent_obj_list)
    for att in att_obj_list:
        ent = ent_by_id.get(att.target_entity_id)
        if ent is None:
            raise ValueError(f"Attribute {att.brat_id} targets unknown entity {att.target_entity_id}")
        ent.att_objs.append(att)
        if att.label == "isAbnormal_OBS":
            ent.abnormality = "Abnormal"
        if att.label == "isNormal_OBS":
            ent.abnormality = "Normal"
        if att.label == "isRelative_Modifier":
            ent.action = att.value
        if att.label == "show_RelativeChange":
            ent.evolution = att.value
    return ent_obj_list, rel_obj_list


def _split_tokens_and_sentences(doc_str):
    """Tokenize ``doc_str`` on single spaces and group the tokens into sentences.

    A sentence ends after a token that is exactly ``"."`` and after the last token.

    Returns:
        doc_tokens: the token strings.
        tok_by_start_char: char offset of a token's first char -> token index.
        tok_by_end_char: char offset one past a token's last char -> token index.
        tokidx2sentidx: sentence index of every token.
        sentences: list of sentences, each a list of token strings.
    """
    doc_tokens = doc_str.split(" ")
    tok_by_start_char = {}
    tok_by_end_char = {}
    tokidx2sentidx = []
    sentences = []
    sent = []
    curr_start = 0
    last_tok_idx = len(doc_tokens) - 1
    for tok_idx, token_str in enumerate(doc_tokens):
        curr_end = curr_start + len(token_str)
        # Start and end offsets strictly increase, so each offset maps to one token.
        tok_by_start_char[curr_start] = tok_idx
        tok_by_end_char[curr_end] = tok_idx
        curr_start = curr_end + 1  # skip the separating space

        tokidx2sentidx.append(len(sentences))
        sent.append(token_str)
        if token_str == "." or tok_idx == last_tok_idx:
            sentences.append(sent)
            sent = []
    return doc_tokens, tok_by_start_char, tok_by_end_char, tokidx2sentidx, sentences


def _locate_entity(ent, doc_tokens, tok_by_start_char, tok_by_end_char, tokidx2sentidx):
    """Set ``ent.start_token_idx`` / ``ent.end_token_idx`` (inclusive) and ``ent.sent_idx``.

    Raises:
        ValueError: if the entity's char offsets are not on token boundaries.
        AssertionError: if the token span differs from the annotated text, or
            the entity spans two sentences.
    """
    if ent.start_index not in tok_by_start_char or ent.end_index not in tok_by_end_char:
        raise ValueError(f"Entity offsets do not align with token boundaries: {ent.get_ann_str().strip()}")
    ent.start_token_idx = tok_by_start_char[ent.start_index]
    ent.end_token_idx = tok_by_end_char[ent.end_index]
    assert ent.token_str == " ".join(doc_tokens[ent.start_token_idx : ent.end_token_idx + 1])

    ent.sent_idx = tokidx2sentidx[ent.start_token_idx]
    assert ent.sent_idx == tokidx2sentidx[ent.end_token_idx]


def _convert_document(dataset_name, datasplit, file_name, stats_key):
    """Convert one annotated report (``.txt`` + ``.ann`` pair) into a JSON-serializable dict.

    Reads ``<brat_data_dir>/<dataset>/<split>/<file_name>`` (first line only)
    and the ``.ann`` file with the same stem. Relation sentence distances are
    appended to ``cross_sent_relations[stats_key]``.

    Returns:
        dict with ``doc_key`` and per-sentence lists ``sentences``, ``ner``
        (``[start_tok, end_tok, label]``), ``relations``
        (``[subj_start, subj_end, obj_start, obj_end, label]``, stored in the
        subject's sentence) and ``entity_attributes``
        (``[start_tok, end_tok, abnormality, action, evolution]``).
        Token indices count over the whole document and are inclusive.
    """
    doc_key = _make_doc_key(dataset_name, file_name)
    brat_split_dir = os.path.join(brat_data_dir, dataset_name, datasplit)
    txt_file = os.path.join(brat_split_dir, file_name)
    ann_file = os.path.join(brat_split_dir, os.path.splitext(file_name)[0] + ".ann")

    # Read the original document: only the first line.
    with open(txt_file, "r", encoding="utf-8") as f:
        doc_str = f.readline().strip()
    with open(ann_file, "r", encoding="utf-8") as f:
        ann_lines = f.readlines()

    ent_obj_list, rel_obj_list = _parse_ann_lines(ann_lines)

    # Labels beyond this range must be excluded: the existing RadGraph labels were
    # presented to the annotators as well, so they have to be filtered out here.
    valid_doc_len = len(doc_str)
    ent_obj_list = [
        ent for ent in ent_obj_list
        if ent.start_index <= valid_doc_len and ent.end_index <= valid_doc_len
    ]
    ent_by_id = _index_by_brat_id(ent_obj_list)
    # Keep only relations whose two arguments both survived the filter above.
    rel_obj_list = [rel for rel in rel_obj_list if rel.arg1 in ent_by_id and rel.arg2 in ent_by_id]

    doc_tokens, tok_by_start_char, tok_by_end_char, tokidx2sentidx, sentences = _split_tokens_and_sentences(doc_str)
    assert len(doc_tokens) == sum(len(sent) for sent in sentences)

    output_dict = {
        "doc_key": doc_key,
        "sentences": sentences,
        "ner": [[] for _ in sentences],
        "relations": [[] for _ in sentences],
        "entity_attributes": [[] for _ in sentences],
    }

    for ent in ent_obj_list:
        _locate_entity(ent, doc_tokens, tok_by_start_char, tok_by_end_char, tokidx2sentidx)

    # Entities (and their attributes) in token order; sorted() is stable, so ties keep file order.
    for ent in sorted(ent_obj_list, key=lambda x: x.start_token_idx):
        output_dict["ner"][ent.sent_idx].append([ent.start_token_idx, ent.end_token_idx, ent.label])
        if ent.att_objs:
            output_dict["entity_attributes"][ent.sent_idx].append(
                [ent.start_token_idx, ent.end_token_idx, ent.abnormality, ent.action, ent.evolution]
            )

    # Relations ordered by subject position, stored in the subject's sentence.
    for rel in sorted(rel_obj_list, key=lambda x: ent_by_id[x.arg1].start_token_idx):
        subj = ent_by_id[rel.arg1]
        obj = ent_by_id[rel.arg2]
        output_dict["relations"][subj.sent_idx].append(
            [subj.start_token_idx, subj.end_token_idx, obj.start_token_idx, obj.end_token_idx, rel.label]
        )
        cross_sent_relations[stats_key].append(abs(subj.sent_idx - obj.sent_idx))

    return output_dict


def bart2json(dataset_names, datasplits, output_name):
    """Convert annotated BRAT reports into one JSON-lines file.

    For every dataset in ``dataset_names`` and split in ``datasplits``, the
    ``.txt`` files listed in ``<to_be_annotated_dir>/<dataset>/label_in_use/<split>``
    are converted in sorted file-name order (for reproducible output) and
    written, one JSON object per line, to ``<output_root_dir>/<output_name>``.

    The output file is rewritten from scratch on each call and the matching
    ``cross_sent_relations`` entries are reset, so re-running this cell does
    not duplicate documents or statistics.

    Args:
        dataset_names: datasets to include; each must be "MIMIC-CXR" or "CheXpert".
        datasplits: split directory names, e.g. ["train"].
        output_name: file name of the JSON-lines output.
    """
    os.makedirs(output_root_dir, exist_ok=True)
    output_path = os.path.join(output_root_dir, output_name)
    with open(output_path, "w", encoding="utf-8") as out_f:
        for dataset_name in dataset_names:
            for datasplit in datasplits:
                stats_key = f"{dataset_name}_{datasplit}_{output_name}"
                cross_sent_relations.pop(stats_key, None)
                label_dir = os.path.join(to_be_annotated_dir, dataset_name, "label_in_use", datasplit)
                for file_name in sorted(os.listdir(label_dir)):
                    # Skip non-report entries such as .DS_Store.
                    if not file_name.endswith(".txt"):
                        continue
                    output_dict = _convert_document(dataset_name, datasplit, file_name, stats_key)
                    out_f.write(json.dumps(output_dict))
                    out_f.write("\n")

In [16]:
# One entry per JSON-lines file: (output file name, datasets, splits).
# test.json pools both test sets; test_mimic.json / test_chexpert.json keep them apart.
conversion_jobs = [
    ("train.json", ["MIMIC-CXR"], ["train"]),
    ("dev.json", ["MIMIC-CXR"], ["dev"]),
    ("test.json", ["MIMIC-CXR", "CheXpert"], ["test"]),
    ("test_mimic.json", ["MIMIC-CXR"], ["test"]),
    ("test_chexpert.json", ["CheXpert"], ["test"]),
]
for job_output_name, job_dataset_names, job_datasplits in conversion_jobs:
    bart2json(dataset_names=job_dataset_names, datasplits=job_datasplits, output_name=job_output_name)

In [17]:
from collections import Counter

# For each output group, count relations by how many sentences separate subject and object.
for stats_key, sent_distances in cross_sent_relations.items():
    distance_counts = Counter(sent_distances).most_common()
    print(
        stats_key,
        "Cross sentence relations: (|subj_sentid - obj_sentid|, num_relation_pairs):",
        distance_counts,
        "",
        sep="\n",
    )

MIMIC-CXR_train_train.json
Cross sentence relations: (|subj_sentid - obj_sentid|, num_relation_pairs):
[(0, 9554), (1, 66)]

MIMIC-CXR_dev_dev.json
Cross sentence relations: (|subj_sentid - obj_sentid|, num_relation_pairs):
[(0, 1706), (1, 11)]

MIMIC-CXR_test_test.json
Cross sentence relations: (|subj_sentid - obj_sentid|, num_relation_pairs):
[(0, 984), (1, 3)]

CheXpert_test_test.json
Cross sentence relations: (|subj_sentid - obj_sentid|, num_relation_pairs):
[(0, 1138), (1, 12), (2, 1)]

MIMIC-CXR_test_test_mimic.json
Cross sentence relations: (|subj_sentid - obj_sentid|, num_relation_pairs):
[(0, 984), (1, 3)]

CheXpert_test_test_chexpert.json
Cross sentence relations: (|subj_sentid - obj_sentid|, num_relation_pairs):
[(0, 1138), (1, 12), (2, 1)]

