# Init

In [1]:
import datasets
from datasets import load_dataset, Sequence, Image, DatasetDict, concatenate_datasets, Dataset
import os
import json
from tqdm import tqdm
import re
import copy
import pandas as pd
import numpy as np
from typing import Union, List
import ast
import linecache
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

  from .autonotebook import tqdm as notebook_tqdm


# Load Ontology

In [2]:
class OntologyNode:
    def __init__(self, row_idx, class_id, class_name, df_row):
        self.row_idx = row_idx
        self.class_id = class_id
        self.class_name = class_name
        self.synonyms = [] if df_row["Synonyms"] == "" else df_row["Synonyms"].split("|")
        self.df_row = df_row

        # The tree structure is maintained by the parent and children attributes. Only one level of parent-child relationship is maintained.
        self.parent = []
        self.children = []
        self.is_root = False
        self.tree_level = None

        # It's parents from all levels
        self._all_parents = []

    def add_child(self, child):
        self.children.append(child)

    def add_parent(self, parent):
        self.parent.append(parent)

    @property
    def all_parents(self):
        if self.is_root:
            return []
        elif self._all_parents:
            return self._all_parents
        else:
            for parent in self.parent:
                # 避免父节点重复
                self._all_parents = set(parent.all_parents + [parent])
                self._all_parents = list(self._all_parents)
            return self._all_parents

    def __eq__(self, other):
        if isinstance(other, OntologyNode):
            return self.class_id == other.class_id
        else:
            return self.class_id == other

    def __hash__(self):
        return hash(self.class_id)

    def __str__(self):
        return f"{self.class_id}: {self.class_name}"

    def __repr__(self):
        return self.__str__()


def set_tree_level(curr_node, tree_level):
    curr_node.tree_level = tree_level
    for child in curr_node.children:
        set_tree_level(child, tree_level + 1)
    if not curr_node.children:
        return

In [3]:
def build_radlex_tree(df_csv):
    # Build a RadLex node list
    node_list = []
    root_node = None
    for idx, row in tqdm(df_csv.iterrows(), total=df_csv.shape[0], desc="Building RadLex tree"):
        ontology_node = OntologyNode(row_idx=idx, class_id=row["Class ID"], class_name=row["Preferred Label"], df_row=row)
        if row["Preferred Label"] in row["Class ID"]:
            ontology_node.class_name = row["http://radlex.org/RID/Preferred_Name_for_Obsolete"]
        node_list.append(ontology_node)

    # Resolve the node list and build a RadLex tree
    for node in tqdm(node_list, total=len(node_list), desc="Building RadLex tree"):
        df_row = node.df_row
        parent_ids = df_row["Parents"].split("|")
        for parent_id in parent_ids:
            parent_row_indices = df_csv.loc[df_csv["Class ID"] == parent_id].index
            if not parent_row_indices.empty:
                parent_row_idx = parent_row_indices[0]
                parent_node = node_list[parent_row_idx]
                assert parent_node.class_id == parent_id
                node.add_parent(parent_node)
                parent_node.add_child(node)
            else:
                # In radlex, http://radlex.org/RID/RID0 has parent http://www.w3.org/2002/07/owl#Thing.
                # However, the RID0 is already the root node in the RadLex ontology. We can safely ignore the owl#Thing.
                root_node = node
                node.is_root = True
                node.tree_level = 0

    return node_list, root_node

In [4]:
radlex_csv_path = "/home/yuxiang/liao/resources/bioportal/radlex/RADLEX.csv"
df_radlex_csv = pd.read_csv(radlex_csv_path, keep_default_na=False)
radlex_nodes, radlex_root_node = build_radlex_tree(df_radlex_csv)
radlex_nodes_dict = {node.class_id: node for node in radlex_nodes}
print(f"Number of RadLex nodes: {len(radlex_nodes)}")

# Tracing all parents of nodes
for node in radlex_nodes:
    node.all_parents

set_tree_level(radlex_root_node, tree_level=0)

Building RadLex tree: 100%|██████████| 46761/46761 [00:02<00:00, 22832.44it/s]
Building RadLex tree: 100%|██████████| 46761/46761 [01:10<00:00, 665.98it/s]


Number of RadLex nodes: 46761


In [5]:
print(radlex_nodes[0])
print(radlex_nodes[0].parent)
print(sorted(radlex_nodes[0].all_parents, key=lambda x: x.tree_level, reverse=True))

http://radlex.org/RID/RID35591: string-of-pearls sign of bowel
[http://radlex.org/RID/RID29023: imaging sign]
[http://radlex.org/RID/RID29023: imaging sign, http://radlex.org/RID/RID5: imaging observation, http://radlex.org/RID/RID1: RadLex entity, http://radlex.org/RID/RID0: RadLex ontology entity]


# Re-orgainze dataset

从effusion开始: http://radlex.org/RID/RID4872

仅保留doc中包含effusion的split_sent

探索这些句子所包含的graph模式

探索句子的共指

探索句子的attr对训练是否有影响

按照img dataset的格式重新组织数据集

In [6]:
data_path = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/combined_results/combine_final"
ds_text = datasets.load_from_disk(data_path)

目标是获取图像的同时也能获取文字
但如果要把文字合并到数据集中的话性能不够
如果文字单独保存的话，要怎么读取？
比如用selectid

可以先读取text数据集，然后用doc_key解析出img数据集，然后用select选择图像
可行

In [7]:
# radlex节点与doc_key的映射

radlex_dockey_dict = defaultdict(list)
radlex_sentkey_dict = defaultdict(list)

for data_row in tqdm(ds_text):
    doc_key = data_row["doc_key"]  # train#0#impression
    for sent_idx, sent in enumerate(data_row["radlex"]):
        for radlex_item in sent:
            # {"char_indices": [20, 31], "match_type": "fuzzy_lemma", "matched_text": "parenchymal", "radlex_id": "http://radlex.org/RID/RID5978", "radlex_name": "parenchyma", "tok_indices": [2, 3]}
            radlex_node = radlex_nodes_dict[radlex_item["radlex_id"]]
            radlex_dockey_dict[radlex_node].append(doc_key)
            radlex_sentkey_dict[radlex_node].append((doc_key, sent_idx))

100%|██████████| 1136366/1136366 [10:34<00:00, 1791.09it/s]


In [8]:
# 处理后的文本数据集与原始的图像数据集的映射
doc_key_map = {
    "findings": {"train": {}, "validation": {}, "test": {}},
    "impression": {"train": {}, "validation": {}, "test": {}},
}

for textDs_row_idx, data_row in enumerate(tqdm(ds_text.select_columns(["doc_key"]))):
    doc_key = data_row["doc_key"]  # train#0#impression
    data_split, imgDs_row_idx, section_name = doc_key.split("#")

    doc_key_map[section_name][data_split][imgDs_row_idx] = textDs_row_idx

100%|██████████| 1136366/1136366 [00:09<00:00, 124860.90it/s]


In [9]:
# 获取目标radlex节点所涉及的的doc_keys

inclusive = {
    "http://radlex.org/RID/RID5": "imaging observation",
    "http://radlex.org/RID/RID34785": "clinical finding",
    "http://radlex.org/RID/RID34861": "object",
    "http://radlex.org/RID/RID1559": "procedure",
    "http://radlex.org/RID/RID35977": "property",
    "http://radlex.org/RID/RID3": "anatomical entity",
    "http://radlex.org/RID/RID6": "RadLex descriptor",
}

# 过滤，仅保留inclusive相关的radlex节点
node_dockey_dict = defaultdict(set)
for node, dockeys in radlex_dockey_dict.items():
    # Check if the node or its parents is in the inclusive list
    if any([cls_id in node.all_parents or cls_id == node.class_id for cls_id in inclusive.keys()]):
        node_dockey_dict[node].update(dockeys)

nodes = [node for node in node_dockey_dict.keys()]
dockey_sets = [set(ids) for ids in node_dockey_dict.values()]

# 按照node在报告中出现的次数排序
# 如果node的父节点已经出现过，那么就不再单独统计，而是将其涉及的doc_keys传递给父节点，并跳过当前node
# 如果node的父节点有多个被添加进统计中，则选择tree_level最近的父节点，比如当前是level4，则选择level3的父节点
aggregrated_nodes, aggregrated_dockey_sets = [], []
for node, key_set in sorted(zip(nodes, dockey_sets), key=lambda x: len(x[1]), reverse=True):
    is_parent_exist = False
    for parent_node in sorted(node.all_parents, key=lambda x: x.tree_level, reverse=True):
        if parent_node in aggregrated_nodes:
            idx = aggregrated_nodes.index(parent_node)
            aggregrated_dockey_sets[idx].update(key_set)
            is_parent_exist = True
            break

    if not is_parent_exist:
        aggregrated_nodes.append(node)
        aggregrated_dockey_sets.append(key_set)

nodes, dockey_sets = zip(*sorted(zip(aggregrated_nodes, aggregrated_dockey_sets), key=lambda x: len(x[1]), reverse=True))
assert len(nodes) == len(dockey_sets)

In [10]:
# 仅保留inclusive相关的radlex节点
print(len(node_dockey_dict))

# 按照node在报告中出现的次数排序
# 如果node的父节点已经出现过，那么就不再单独统计，而是将其涉及的doc_keys传递给父节点，并跳过当前node
# 如果node的父节点有多个被添加进统计中，则选择tree_level最近的父节点，比如当前是level4，则选择level3的父节点
# 具体见 5_radlex_ontology.py 合并数据
print(len(nodes))

3279
1278


## Output all nodes

In [11]:
target_dockey_set = set([doc_key for doc_keys in dockey_sets for doc_key in doc_keys])
print(f"Num of total involved report sections: {len(target_dockey_set)}")


def ds_generator(target_section_name, target_data_split):
    for doc_key in target_dockey_set:
        data_split, imgDs_row_idx, section_name = doc_key.split("#")

        if data_split != target_data_split or section_name != target_section_name:
            continue

        ds_text_row_idx = doc_key_map[section_name][data_split][imgDs_row_idx]
        ds_text_data_row = ds_text[ds_text_row_idx]
        # ds_text_data_row: dict_keys(['doc_key', 'sent_toks', 'tok_char_indices', 'sents', 'sent_char_indices', 'split_sents', 'sent_idx_split_idx', 'split_sent_toks', 'split_tok_char_indices', 'radlex', 'cxrgraph_ent', 'cxrgraph_attr', 'cxrgraph_rel', 'radcoref'])

        if ds_text_data_row["sents"] == []:
            continue

        # 不需要任何过滤 radlex 节点。使用所有内容
        output_data_row = {
            "doc_key": doc_key,
            "sents": ds_text_data_row["sents"],
            "sent_toks": ds_text_data_row["sent_toks"],
            "tok_char_indices": ds_text_data_row["tok_char_indices"],
            "split_sents": ds_text_data_row["split_sents"],
            "split_sent_toks": ds_text_data_row["split_sent_toks"],
            "sent_idx_split_idx": ds_text_data_row["sent_idx_split_idx"],
            "radlex": ds_text_data_row["radlex"],
            "cxrgraph_ent": ds_text_data_row["cxrgraph_ent"],
            "cxrgraph_attr": ds_text_data_row["cxrgraph_attr"],
            "cxrgraph_rel": ds_text_data_row["cxrgraph_rel"],
            "radcoref": ds_text_data_row["radcoref"],
        }

        assert len(output_data_row["split_sents"]) != 0

        yield output_data_row

Num of total involved report sections: 733093


In [12]:
output_dir = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_text/all"
os.makedirs(output_dir, exist_ok=True)


for section_name in ["findings", "impression"]:
    ds_dict = {}
    for data_split in ["train", "validation", "test"]:
        ds_dict[data_split] = Dataset.from_generator(ds_generator, gen_kwargs={"target_section_name": section_name, "target_data_split": data_split})
    dataset_dict_final = DatasetDict(ds_dict)

    output_path = os.path.join(output_dir, f"interpret_text_{section_name}")
    dataset_dict_final.save_to_disk(output_path)

Generating train split: 343738 examples [13:03, 438.94 examples/s]
Generating train split: 8825 examples [00:20, 432.38 examples/s]
Generating train split: 2692 examples [00:09, 291.89 examples/s]
Saving the dataset (6/6 shards): 100%|██████████| 343738/343738 [00:01<00:00, 228124.91 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 8825/8825 [00:00<00:00, 223079.76 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2692/2692 [00:00<00:00, 136158.34 examples/s]
Generating train split: 365565 examples [12:17, 495.45 examples/s] 
Generating train split: 9308 examples [00:18, 495.19 examples/s] 
Generating train split: 2965 examples [00:05, 578.91 examples/s] 
Saving the dataset (6/6 shards): 100%|██████████| 365565/365565 [00:01<00:00, 251863.07 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 9308/9308 [00:00<00:00, 211098.64 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2965/2965 [00:00<00:00, 190393.33 examples/s]


## Focus on radlex node: "effusion"

重新构造一个文本数据子集，
只保留目标split_sent，dockey，cxrgraphs，radlex

radcoref用于协助句子选取

现在的主要问题是需要把radcoref用上。

比如我用radlex，那么这几个radlex就视为coref。但实际上不是，因为左右有区别。
但目前的任务暂时用不上，因为我们不区分左右。只考虑effusion。

In [None]:
target_node_idx = nodes.index("http://radlex.org/RID/RID4872")
target_node = nodes[target_node_idx]
target_dockey_set = dockey_sets[target_node_idx]
print(f"Target node: {target_node.class_name} ({len(target_dockey_set)} report sections)")


def ds_generator(target_section_name, target_data_split, target_radlex_node_id):
    for doc_key in target_dockey_set:
        data_split, imgDs_row_idx, section_name = doc_key.split("#")

        if data_split != target_data_split or section_name != target_section_name:
            continue

        ds_text_row_idx = doc_key_map[section_name][data_split][imgDs_row_idx]
        ds_text_data_row = ds_text[ds_text_row_idx]

        output_data_row = {
            "doc_key": doc_key,
            "split_sents": [],
            "split_sent_toks": [],
            "sent_idx_split_idx": [],
            "radlex": [],
            "cxrgraph_ent": [],
            "cxrgraph_attr": [],
            "cxrgraph_rel": [],
        }

        for sent_idx, sent_radlex in enumerate(ds_text_data_row["radlex"]):
            is_contain_target_radlex = False

            # 判断这个句子是否包含目标radlex节点
            for radlex_item in sent_radlex:
                curr_radlex_node = radlex_nodes_dict[radlex_item["radlex_id"]]
                if target_radlex_node_id == curr_radlex_node.class_id or target_radlex_node_id in curr_radlex_node.all_parents:
                    is_contain_target_radlex = True
                    break

            # 如果句子包含目标radlex节点，那么就把这个句子的相关内容，比如radlex， cxrgraph 加入到数据集中
            # radcoref暂时不需要，因为我们只关注一个radlex节点，不需要判断句子之间的关系。在img2sent时，默认生成所有句子
            if is_contain_target_radlex:
                output_data_row["split_sents"].append(ds_text_data_row["split_sents"][sent_idx])
                output_data_row["split_sent_toks"].append(ds_text_data_row["split_sent_toks"][sent_idx])
                output_data_row["sent_idx_split_idx"].append(ds_text_data_row["sent_idx_split_idx"][sent_idx])
                output_data_row["radlex"].append(ds_text_data_row["radlex"][sent_idx])
                output_data_row["cxrgraph_ent"].append(ds_text_data_row["cxrgraph_ent"][sent_idx])
                output_data_row["cxrgraph_attr"].append(ds_text_data_row["cxrgraph_attr"][sent_idx])
                output_data_row["cxrgraph_rel"].append(ds_text_data_row["cxrgraph_rel"][sent_idx])

        assert len(output_data_row["split_sents"]) != 0

        yield output_data_row

In [None]:
test_impr_ds = Dataset.from_generator(ds_generator, gen_kwargs={"target_section_name": "impression", "target_data_split": "test", "target_radlex_node_id": target_node.class_id})
test_impr_ds

In [None]:
output_dir = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_text/effusion"
os.makedirs(output_dir, exist_ok=True)


for section_name in ["findings", "impression"]:
    ds_dict = {}
    for data_split in ["train", "validation", "test"]:
        ds_dict[data_split] = Dataset.from_generator(ds_generator, gen_kwargs={"target_section_name": section_name, "target_data_split": data_split, "target_radlex_node_id": target_node.class_id})
    dataset_dict_final = DatasetDict(ds_dict)

    output_path = os.path.join(output_dir, f"interpret_text_{section_name}_effusion")
    dataset_dict_final.save_to_disk(output_path)