# Resolve the annotated data

Run prepare_ann_data.ipynb to get necessary resources

Put the annotated data into /resources/radgraph_plus

Check /src/config/graph_annotation_process/resolve_brat.yaml for settings


In [48]:
import sys
sys.path.append("../../src")
import os
import json
from collections import defaultdict

from common_utils.file_checker import FileChecker
from common_utils.common_utils import check_and_create_dirs, check_and_remove_file
FILE_CHECKER = FileChecker()

In [49]:
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the `graph_annotation_process` config with the `resolve_brat`
# override applied at the global level, then print the result as YAML.
config = None
with initialize(
    version_base=None,
    config_path="../config",
    job_name="radgraph_to_brat",
):
    config = compose(
        config_name="graph_annotation_process",
        overrides=["graph_annotation_process@_global_=resolve_brat"],
    )
print(OmegaConf.to_yaml(config))

machine:
  work_dir: /home/yuxiangliao/PhD/workspace/VSCode_workspace/structured_reporting
  fast_coref_dir: /home/yuxiangliao/PhD/workspace/git_clone_repos/fast-coref
work_dir: ${machine.work_dir}
src_dir: ${work_dir}/src
output_dir: ${work_dir}/output
resource_dir: ${work_dir}/resources
base_output_dir: ${work_dir}/output
mimic_cxr_output_dir: ${base_output_dir}/mimic_cxr
log_dir: ${work_dir}/logs/${hydra.job.config_name}
logging_level: INFO
ignore_source_path: ${work_dir}/config/ignore/common_ignore
fast_coref_dir: ${machine.fast_coref_dir}
coref_scorer_dir: ${machine.fast_coref_dir}/coref_resources/reference-coreference-scorers
input:
  base_dir: ${resource_dir}/radgraph_plus
  annotator_L: ${input.base_dir}/annotator_L
  annotator_X: ${input.base_dir}/annotator_X
output:
  base_dir: ${base_output_dir}/radgraph_plus
brat_source:
  for_ann_dir: ${base_output_dir}/radgraph/brat_data_for_annotation



## Read the data

Only one annotator's results are available so far.

Automatically detect whether a report is annotated or not.

For .txt files, only the first line is the valid report.
For .ann files, ignore annotations whose indices fall outside that first line.

### Get all files to be resolved

In [50]:
# Map every directory under annotator L's folder to the set of file name
# prefixes (file names without extension) found in it. A set is used so that
# files sharing a prefix (e.g. `x.ann` and `x.txt`) yield a single entry.
data_dict:dict[str,set[str]] = defaultdict(set)
for root_path, dir_list, file_list in os.walk(config.input.annotator_L):
    for file_name in FILE_CHECKER.filter(file_list):
        # splitext removes the extension whatever its length, instead of
        # assuming a dot plus exactly three characters.
        file_name_prefix = os.path.splitext(file_name)[0]
        data_dict[root_path].add(file_name_prefix)

### Check and get newly annotated data

In [51]:
# Some of the reports are not annotated yet, so we need to distinguish them.
def get_new_ann_results(root_path, file_prefix, ann_lines):
    """Return the annotation lines whose id is not one of the old labels.

    Args:
        root_path: Directory that contains the annotated file.
        file_prefix: File name without extension.
        ann_lines: Raw lines read from the `.ann` file.

    Returns:
        The lines of `ann_lines`, in their original order, whose text before
        the first tab is not among the labels returned by
        `get_old_labels(root_path, file_prefix)`. The list is empty when the
        file holds no new annotation.
    """
    # Build the set once so each membership test is O(1) rather than a scan
    # of the whole label list for every annotation line.
    old_labels = set(get_old_labels(root_path, file_prefix))
    return [line for line in ann_lines if line.split("\t")[0] not in old_labels]
    
def get_old_labels(root_path,file_prefix):
    """Read the labels that were already in use for one report.

    The last two components of `root_path` are taken as the dataset name and
    the data split name. The labels are read from
    `<for_ann_dir>/<dataset>/label_in_use/<split>/<file_prefix>.txt`.

    Args:
        root_path: Directory that contains the annotated file.
        file_prefix: File name without extension.

    Returns:
        One whitespace-stripped string per line of the label file.
    """
    path_parts = root_path.split(os.sep)
    dataset_name, datasplit_name = path_parts[-2], path_parts[-1]
    label_file = os.path.join(
        config.brat_source.for_ann_dir,
        dataset_name,
        "label_in_use",
        datasplit_name,
        file_prefix+".txt",
    )
    with open(label_file,"r",encoding="utf-8") as f:
        return [raw_line.strip() for raw_line in f]

### Data classes and resolving functions

In [52]:
from abc import ABC, abstractmethod
import re

class AnnClass(ABC):
    """Abstract base for one parsed annotation line of an `.ann` file.

    Subclasses must implement `__init__`, `resolve_line`, `get_ann_str` and
    `get_json_dict`; the class cannot be instantiated until they do.
    """

    @abstractmethod
    def __init__(self, _id, _type) -> None:
        """Store the annotation id and type shared by every annotation kind."""
        self.id = _id
        self.type = _type

    @classmethod
    @abstractmethod
    def resolve_line(cls, ann_line:str) -> "AnnClass":
        """Parse one raw annotation line and return an instance of `cls`."""

    @abstractmethod
    def get_ann_str(self) -> str:
        """Return the annotation serialized as one line of `.ann` text."""

    @abstractmethod
    def get_json_dict(self) -> dict:
        """Return the annotation as a dict for the JSON output."""

    def __repr__(self) -> str:
        # The serialized annotation line doubles as the debug representation.
        return self.get_ann_str()

class AnnEntityClass(AnnClass):
    """Entity annotation: a typed text span, `ID<TAB>TYPE START END<TAB>TEXT`.

    Instances compare equal to another entity with the same id and also to a
    bare id value, so a list of entities can be searched by id.
    """

    def __init__(self, _id, _type, _start_index, _end_index, _token_str) -> None:
        super().__init__(_id, _type)
        # When built by resolve_line the indices are regex groups, i.e. strings.
        self.start_index, self.end_index = _start_index, _end_index
        self.token_str = _token_str

    @classmethod
    def resolve_line(cls, ann_line:str):
        """Parse an entity line; raises AttributeError if it does not match."""
        parsed = re.match(r"(.*)\t(.*) (\d*) (\d*)\t(.*)", ann_line.strip())
        # `parsed` is None for a malformed line, so `.groups()` raises
        # AttributeError.
        entity_id, label, start, end, text = parsed.groups()
        return cls(entity_id, label, start, end, text)

    def get_ann_str(self) -> str:
        """Serialize the entity back to one `.ann` line (newline included)."""
        span = f"{self.start_index} {self.end_index}"
        return f"{self.id}\t{self.type} {span}\t{self.token_str}\n"

    def get_json_dict(self) -> dict:
        """Return the entity's text, label and span for the JSON output."""
        # NOTE(review): the keys are "start_idx" but "end_index"; kept as-is
        # because they are part of the emitted JSON.
        return dict(
            tokens=self.token_str,
            label=self.type,
            start_idx=self.start_index,
            end_index=self.end_index,
        )

    def __eq__(self, o):
        # Compare by id, whether `o` is another entity or a raw id value.
        other_id = o.id if isinstance(o, AnnEntityClass) else o
        return self.id == other_id

    def __hash__(self):
        return hash(self.id)

class AnnRelationClass(AnnClass):
    """Relation annotation: `ID<TAB>TYPE Arg1:ID Arg2:ID`.

    `_sorted_arg1` and `_sorted_arg2` start as None and are assigned from
    outside the class; `get_json_dict` reports them rather than the raw
    argument ids.
    """

    def __init__(self, _id, _type, _arg1, _arg2) -> None:
        super().__init__(_id, _type)
        self.arg1, self.arg2 = _arg1, _arg2
        self._sorted_arg1 = self._sorted_arg2 = None

    @classmethod
    def resolve_line(cls, ann_line:str):
        """Parse a relation line; raises AttributeError if it does not match."""
        parsed = re.match(r"(.*)\t(.*) Arg1:(.*) Arg2:(.*)", ann_line.strip())
        # `parsed` is None for a malformed line, so `.groups()` raises
        # AttributeError.
        relation_id, label, first_arg, second_arg = parsed.groups()
        return cls(relation_id, label, first_arg, second_arg)

    def get_ann_str(self) -> str:
        """Serialize the relation back to one `.ann` line (newline included)."""
        arguments = f"Arg1:{self.arg1} Arg2:{self.arg2}"
        return f"{self.id}\t{self.type} {arguments}\t\n"

    def get_json_dict(self) -> dict:
        """Return the relation label and its two `_sorted_*` argument values."""
        return dict(
            label=self.type,
            entity1=self._sorted_arg1,
            entity2=self._sorted_arg2,
        )

class AnnAttributeClass(AnnClass):
    """Attribute annotation: `ID<TAB>TYPE REFERRED_ID [VALUE]`.

    `type_content` holds the optional value and is the empty string when the
    line has none. `_sorted_referred_id` starts as None and is assigned from
    outside the class; `get_json_dict` reports it rather than the raw id.
    """

    def __init__(self, _id, _type, _referred_id, _type_content="") -> None:
        super().__init__(_id, _type)
        self.referred_id = _referred_id
        self.type_content = _type_content
        self._sorted_referred_id = None

    @classmethod
    def resolve_line(cls, ann_line:str):
        """Parse an attribute line; raises AttributeError if it has no tab."""
        pattern = r"(.*)\t(.*)"
        match_obj = re.match(pattern, ann_line.strip())
        # The part after the tab is "TYPE REFERRED_ID" or
        # "TYPE REFERRED_ID VALUE"; the space-separated pieces fill the
        # remaining constructor arguments positionally.
        return cls(match_obj.group(1), *match_obj.group(2).split(" "))

    def get_ann_str(self) -> str:
        """Serialize the attribute back to one `.ann` line (newline included)."""
        head = f"{self.id}\t{self.type} {self.referred_id}"
        # Append the value only when there is one, so an attribute without a
        # value serializes without a dangling trailing space.
        if self.type_content:
            return f"{head} {self.type_content}\n"
        return f"{head}\n"

    def get_json_dict(self) -> dict:
        """Return the label (`TYPE` or `TYPE:VALUE`) and the referred entity."""
        return {
            "label": f"{self.type}:{self.type_content}" if self.type_content else self.type,
            "entity": self._sorted_referred_id
        }

In [53]:
import traceback

def resolve_lines(ann_lines:list) -> tuple[list,list,list]:
    """Parse raw annotation lines into entity, relation and attribute objects.

    The first character of a line selects the parser: "T" for entities, "R"
    for relations, "A" for attributes. Empty and whitespace-only lines are
    skipped.

    Args:
        ann_lines: Raw lines of an `.ann` file.

    Returns:
        A tuple `(entity_obj_list, relation_obj_list, attribute_obj_list)`.

    Raises:
        ValueError: If a non-blank line starts with any other character.
    """
    entity_obj_list = []
    relation_obj_list = []
    attribute_obj_list = []
    for line in ann_lines:
        # A blank line carries no annotation. Without this guard "" raises
        # IndexError and "\n" is rejected as an unrecognized line.
        if not line.strip():
            continue
        try:
            if line.startswith("T"):
                entity_obj_list.append(AnnEntityClass.resolve_line(line))
            elif line.startswith("R"):
                relation_obj_list.append(AnnRelationClass.resolve_line(line))
            elif line.startswith("A"):
                attribute_obj_list.append(AnnAttributeClass.resolve_line(line))
            else:
                raise ValueError("Not recoginzed line", line)
        except AttributeError:
            # Best effort: a line that its parser cannot match surfaces as an
            # AttributeError; report it and keep processing the other lines.
            traceback.print_exc()
            print(line)
    return entity_obj_list, relation_obj_list, attribute_obj_list

def check_index_clash(report:str, entity_obj_list:list[AnnEntityClass]):
    """Verify that every entity's text equals the report slice at its indices.

    Args:
        report: The report text the entity indices refer to.
        entity_obj_list: Entities to validate.

    Returns:
        True when every entity matches.

    Raises:
        ValueError: On the first entity whose `token_str` differs from
            `report[start_index:end_index]`.
    """
    for entity in entity_obj_list:
        begin, end = int(entity.start_index), int(entity.end_index)
        covered_text = report[begin:end]
        if entity.token_str != covered_text:
            raise ValueError(f"`{entity.token_str}` != `{covered_text}` at [{entity.start_index}:{entity.end_index}]")
    return True

### Convert class obj to json



In [54]:
import json
def format_output(report:str, entity_obj_list:list[AnnEntityClass], relation_obj_list:list[AnnRelationClass], attribute_obj_list:list[AnnAttributeClass]):
    """Build the JSON-ready dict for one report.

    Entities are ordered by (start, end) index. Relations are ordered by the
    numeric part of their two argument ids, and attributes by the numeric
    part of the id they refer to.

    Returns:
        `{"text": report, "labeler_1": {"entities": ..., "relations": ...,
        "attributes": ...}}`, where each inner value maps a running index
        (as a string) to the object's `get_json_dict()`.
    """
    entities_in_order = sorted(entity_obj_list, key=lambda obj: (int(obj.start_index), int(obj.end_index)))
    # Must run before relations/attributes are converted: it fills in the
    # entity positions that their get_json_dict() reports.
    update_sorted_entity_idx_to_objs(entities_in_order, relation_obj_list, attribute_obj_list)
    relations_in_order = sorted(relation_obj_list, key=lambda obj: (int(obj.arg1[1:]), int(obj.arg2[1:])))
    attributes_in_order = sorted(attribute_obj_list, key=lambda obj: int(obj.referred_id[1:]))
    annotation_dict = {
        "entities": ann_objs_to_dict(entities_in_order),
        "relations": ann_objs_to_dict(relations_in_order),
        "attributes": ann_objs_to_dict(attributes_in_order),
    }
    return {"text": report, "labeler_1": annotation_dict}

def update_sorted_entity_idx_to_objs(sorted_entity_obj_list:list[AnnEntityClass], relation_obj_list:list[AnnRelationClass], attribute_obj_list:list[AnnAttributeClass]):
    """Record, on each relation and attribute, the position of its entities.

    Sets `_sorted_arg1` and `_sorted_arg2` on every relation, and
    `_sorted_referred_id` on every attribute, to the position (as a string)
    of the referenced entity in `sorted_entity_obj_list`.

    Raises:
        ValueError: If a referenced entity id is not in the entity list.
    """
    # One pass builds id -> position; setdefault keeps the first occurrence,
    # matching what list.index() would return for duplicated ids.
    position_by_id = {}
    for position, entity in enumerate(sorted_entity_obj_list):
        position_by_id.setdefault(entity.id, str(position))

    def _position_of(entity_id):
        # Raise ValueError, as list.index() does, for an unknown id.
        try:
            return position_by_id[entity_id]
        except KeyError:
            raise ValueError(f"{entity_id!r} is not in list") from None

    for obj in relation_obj_list:
        obj._sorted_arg1 = _position_of(obj.arg1)
        obj._sorted_arg2 = _position_of(obj.arg2)
    for obj in attribute_obj_list:
        obj._sorted_referred_id = _position_of(obj.referred_id)

def ann_objs_to_dict(sorted_ann_obj_list:list[AnnClass]):
    """Map each object's list position (as a string) to its `get_json_dict()`."""
    return {
        str(position): ann_obj.get_json_dict()
        for position, ann_obj in enumerate(sorted_ann_obj_list)
    }

### Run

In [55]:
# For every directory collected in data_dict, convert the newly annotated
# reports to JSON and write one `<data split>.json` file per directory under
# `<output base dir>/<dataset>`.
processing_progress = defaultdict(int)
for root_path, file_prefix_list in data_dict.items():
    # The directory name is the data split; its parent is the dataset.
    dataset_name = os.path.basename(os.path.dirname(root_path))
    datasplit_name = os.path.basename(root_path)
    output_dir = os.path.join(config.output.base_dir, dataset_name)
    check_and_create_dirs(output_dir)
    # check_and_remove_file(os.path.join(output_dir, datasplit_name+".json")) # Re-create files
    data_split_dict = {}
    for file_prefix in file_prefix_list:
        # Read ann file
        with open(os.path.join(root_path, file_prefix+".ann"),"r",encoding="utf-8") as f_ann:
            ann_lines = f_ann.readlines()
        new_ann_lines = get_new_ann_results(root_path, file_prefix, ann_lines)
        # Reports without any new annotation line are not annotated yet.
        if not new_ann_lines:
            continue
        entity_obj_list, relation_obj_list, attribute_obj_list = resolve_lines(new_ann_lines)
        # Read txt file: only its first line is the report.
        with open(os.path.join(root_path, file_prefix+".txt"),"r",encoding="utf-8") as f_txt:
            report = f_txt.readline().strip()
        check_index_clash(report, entity_obj_list) # raise ValueError if not valid
        report_dict = format_output(report,entity_obj_list, relation_obj_list, attribute_obj_list)
        report_dict.update(data_source=dataset_name, data_split=datasplit_name)
        data_split_dict[file_prefix] = report_dict
        processing_progress[f"{dataset_name}_{datasplit_name}"] += 1
    # The file is written even when no report of this split was annotated.
    with open(os.path.join(output_dir, datasplit_name+".json"), "w", encoding="utf-8") as f:
        f.write(json.dumps(data_split_dict,indent=4))
print(processing_progress)

defaultdict(<class 'int'>, {'MIMIC-CXR_test': 50, 'MIMIC-CXR_dev': 20})
