# Init

In [1]:
import datasets
from datasets import load_dataset, Sequence, Image, DatasetDict, concatenate_datasets, Dataset
import os
import json
from tqdm import tqdm
import re
import copy
import pandas as pd
import numpy as np
from typing import Union, List
import ast
import linecache
from collections import defaultdict, Counter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_jsonline_from_file(file_path, line_idx):
    """Return the parsed JSON value stored on one line of a JSON-lines file.

    Args:
        file_path: Path of the JSON-lines file.
        line_idx: Zero-based line index (``linecache.getline`` is one-based,
            hence the ``+ 1`` below).

    Returns:
        The decoded JSON value, or ``None`` when ``linecache`` yields an empty
        string for that line.
    """
    raw_line = linecache.getline(file_path, line_idx + 1)
    if not raw_line:
        return None
    return json.loads(raw_line.strip())

In [3]:
# Directory that receives the intermediate checkpoints written by save_to_temp().
temp_dir = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/combined_results"


def save_to_temp(ds, version):
    """Write ``ds`` to ``<temp_dir>/temp_v<version>`` and return that path.

    Args:
        ds: Object exposing ``save_to_disk(path)``.
        version: Value interpolated into the checkpoint directory name.

    Returns:
        The path the dataset was saved to.
    """
    checkpoint_path = os.path.join(temp_dir, f"temp_v{version}")
    ds.save_to_disk(checkpoint_path)
    return checkpoint_path

# Load spacy results for reports

In [None]:
report_file = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_reports/raw_reports.json"
# Sanity check: print the first two raw lines of the report file to eyeball the record layout.
with open(report_file, "r") as f:
    print(next(f))
    print(next(f))

In [None]:
# Build the working dataset from the report file; later cells add columns to it via map().
new_ds = Dataset.from_json(report_file)

In [None]:
# Display the dataset repr to check what was loaded.
new_ds

# Load llm-sent-gen results

In [None]:
# Directory holding the LLM sentence-splitting output files.
llm_file_dir = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/llm_split_sents"

# Sanity check: print the first two raw lines of the first file.
with open(os.path.join(llm_file_dir, "llm_split_sents_1_of_3.json"), "r") as f:
    print(next(f))
    print(next(f))

In [None]:
# Index every LLM record by its doc_key. Only the record's location (file path + zero-based
# line index) is kept, so the full JSON can be re-read on demand by load_jsonline_from_file().
doc_map = defaultdict(list)

# The output is spread over three files: llm_split_sents_{1,2,3}_of_3.json.
for file_idx in range(1, 4):
    target_file_path = os.path.join(llm_file_dir, f"llm_split_sents_{file_idx}_of_3.json")
    with open(target_file_path, "r") as f:
        for line_idx, line in enumerate(tqdm(f)):
            doc = json.loads(line.strip())
            # NOTE: the "split_sent_idx" key stores the record's "sent_idx" field here.
            doc_map[doc["doc_key"]].append({"doc_key": doc["doc_key"], "split_sent_idx": int(doc["sent_idx"]), "file_path": target_file_path, "line_idx": line_idx})

In [None]:
def update_dataset(element):
    """Attach the LLM-produced split sentences to one report row.

    Looks up the row's records in the module-level ``doc_map``, re-reads each
    one from disk and adds two parallel columns:

    * ``split_sents``: every non-blank split sentence, in record order.
    * ``sent_idx_split_idx``: for each kept sentence, the pair
      ``(sent_idx of the record, position inside the record's split_sents)``.

    Args:
        element: One dataset row; must provide ``doc_key`` and ``sents``.

    Returns:
        The same ``element`` with the two columns added.

    Raises:
        AssertionError: If a record's ``doc_key`` or ``original_sent`` does
            not match the row.
    """
    doc_key = element["doc_key"]
    # In this doc_map the "split_sent_idx" key carries the record's "sent_idx" value.
    ordered_infos = sorted(doc_map[doc_key], key=lambda info: info["split_sent_idx"])

    kept_sents = []
    kept_positions = []
    for info in ordered_infos:
        # file_doc = {"doc_key":"train#0#impression","sent_idx":1,"original_sent":"STABLE SMALL LEFT PLEURAL EFFUSION.","split_sents":["Stable small left pleural effusion."]}
        file_doc = load_jsonline_from_file(info["file_path"], info["line_idx"])
        assert element["doc_key"] == file_doc["doc_key"]
        assert element["sents"][file_doc["sent_idx"]] == file_doc["original_sent"]

        for split_idx, split_sent in enumerate(file_doc["split_sents"]):
            # Whitespace-only outputs are dropped; split_idx still counts them.
            if split_sent.strip() != "":
                kept_sents.append(split_sent)
                kept_positions.append((file_doc["sent_idx"], split_idx))

    element["split_sents"] = kept_sents
    element["sent_idx_split_idx"] = kept_positions
    return element


# Debugging aid: uncomment to work on the first 10 rows only.
# temp_ds = new_ds.select(range(10))
# Run update_dataset over the dataset and keep the result.
new_ds = new_ds.map(update_dataset)

In [None]:
# Spot-check a single row after the merge.
new_ds[37]

In [None]:
# Checkpoint the dataset as "temp_v1" under temp_dir; the trailing expression displays the path.
temp_path = save_to_temp(new_ds, version=1)
temp_path

# Load spacy results for sentences

In [None]:
# Reload the "temp_v1" checkpoint so this section can start from disk.
temp_path = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/combined_results/temp_v1"
new_ds = Dataset.load_from_disk(temp_path)
new_ds

In [None]:
spacy_sent_file = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/raw/raw_sents.json"
# Sanity check: print the first two raw lines of the file.
with open(spacy_sent_file, "r") as f:
    print(next(f))
    print(next(f))

In [None]:
# Rebuild doc_map for the spaCy sentence file: report-level doc_key -> list of record locations.
doc_map = defaultdict(list)

with open(spacy_sent_file, "r") as f:
    for line_idx, line in enumerate(tqdm(f)):
        doc = json.loads(line.strip())
        # The record key has five "#"-separated parts; the first three form the report-level doc_key.
        data_split, row_idx, section_name, orig_sent_idx, split_sent_idx = doc["doc_key"].split("#")
        doc_key = f"{data_split}#{row_idx}#{section_name}"

        doc_map[doc_key].append({"doc_key": doc_key, "sent_idx": int(orig_sent_idx), "split_sent_idx": int(split_sent_idx), "file_path": spacy_sent_file, "line_idx": line_idx})

In [None]:
def update_dataset(element):
    """Attach spaCy tokens and token character offsets to every split sentence of a row.

    Adds ``split_sent_toks`` and ``split_tok_char_indices``, both lists parallel
    to ``element["split_sents"]``. A slot stays an empty list when the
    module-level ``doc_map`` has no record for that split sentence.

    Args:
        element: One dataset row; must provide ``doc_key``, ``split_sents`` and
            ``sent_idx_split_idx``.

    Returns:
        The same ``element`` with the two columns added.

    Raises:
        AssertionError: If a record does not line up with the row (key, text,
            or an unexpected number of sentences in the record).
        ValueError: If a record's (sent_idx, split_idx) pair is not present in
            ``sent_idx_split_idx``.
    """
    num_splits = len(element["split_sents"])
    element["split_sent_toks"] = [[] for _ in range(num_splits)]
    element["split_tok_char_indices"] = [[] for _ in range(num_splits)]
    if num_splits == 0:
        return element

    ordered_infos = sorted(doc_map[element["doc_key"]], key=lambda info: (info["sent_idx"], info["split_sent_idx"]))
    for info in ordered_infos:
        # file_doc = {"doc_key": "train#0#impression#0#1", "split_sent_text": "Decreased bibasilar parenchymal opacities are now minimal.", "split_sent_toks": [["Decreased", "bibasilar", "parenchymal", "opacities", "are", "now", "minimal", "."]], "tok_char_indices": [[[0, 9], [10, 19], [20, 31], [32, 41], [42, 45], [46, 49], [50, 57], [57, 58]]]}
        file_doc = load_jsonline_from_file(info["file_path"], info["line_idx"])
        data_split, row_idx, section_name, sent_no, split_no = file_doc["doc_key"].split("#")
        sent_no, split_no = int(sent_no), int(split_no)
        assert info["sent_idx"] == sent_no and info["split_sent_idx"] == split_no
        assert element["doc_key"] == f"{data_split}#{row_idx}#{section_name}"

        # Locate the slot of this split sentence inside the row's parallel lists.
        slot = element["sent_idx_split_idx"].index([sent_no, split_no])
        assert element["split_sents"][slot] == file_doc["split_sent_text"]

        # Each record is expected to hold exactly one tokenised sentence.
        assert len(file_doc["split_sent_toks"]) == 1
        assert len(file_doc["tok_char_indices"]) == 1

        element["split_sent_toks"][slot] = file_doc["split_sent_toks"][0]
        element["split_tok_char_indices"][slot] = file_doc["tok_char_indices"][0]

    assert len(element["split_sent_toks"]) == len(element["split_sents"])
    return element


# Debugging aid: uncomment to work on the first 10 rows only.
# temp_ds = new_ds.select(range(10))
# Run update_dataset over the dataset and keep the result.
new_ds = new_ds.map(update_dataset)

In [None]:
# Checkpoint the dataset as "temp_v2" under temp_dir; the trailing expression displays the path.
temp_path = save_to_temp(new_ds, version=2)
temp_path

# Load radlex results

## Load radlex ontology

In [None]:
class OntologyNode:
    """One ontology class, linked to its direct parents and direct children.

    Nodes compare equal (and hash) by ``class_id``; a node also compares equal
    to a bare value that equals its ``class_id``.
    """

    def __init__(self, row_idx, class_id, class_name, df_row):
        """Create an unlinked node.

        Args:
            row_idx: Index of the source row.
            class_id: Identifier of the ontology class.
            class_name: Display name of the class.
            df_row: Mapping for the source row; must provide ``"Synonyms"``
                as a ``"|"``-separated string (empty string = no synonyms).
        """
        self.row_idx = row_idx
        self.class_id = class_id
        self.class_name = class_name
        self.synonyms = [] if df_row["Synonyms"] == "" else df_row["Synonyms"].split("|")
        self.df_row = df_row

        # The tree structure is maintained by the parent and children attributes.
        # Only one level of parent-child relationship is stored on each node.
        self.parent = []
        self.children = []
        self.is_root = False
        # Depth in the tree; filled in from outside the class.
        self.tree_level = None

        # Cache of the ancestors from all levels, filled lazily by `all_parents`.
        self._all_parents = []

    def add_child(self, child):
        """Register ``child`` as a direct child of this node."""
        self.children.append(child)

    def add_parent(self, parent):
        """Register ``parent`` as a direct parent of this node."""
        self.parent.append(parent)

    @property
    def all_parents(self):
        """All ancestors of this node, de-duplicated, across every direct parent.

        The root returns an empty list. For other nodes the result is computed
        on first access and cached. The ancestors reachable through *each*
        direct parent are merged, so a node with several parents keeps all of
        them. Order is deterministic: first-seen order while walking the direct
        parents in insertion order.
        """
        if self.is_root:
            return []
        if not self._all_parents:
            # A dict is used as an insertion-ordered set (keys hash by class_id).
            merged = {}
            for parent in self.parent:
                for ancestor in parent.all_parents:
                    merged.setdefault(ancestor, None)
                merged.setdefault(parent, None)
            self._all_parents = list(merged)
        return self._all_parents

    def __eq__(self, other):
        if isinstance(other, OntologyNode):
            return self.class_id == other.class_id
        else:
            return self.class_id == other

    def __hash__(self):
        return hash(self.class_id)

    def __str__(self):
        return f"{self.class_id}: {self.class_name}"

    def __repr__(self):
        return self.__str__()


def set_tree_level(curr_node, tree_level):
    """Recursively stamp ``tree_level`` on ``curr_node`` and deeper levels on its descendants.

    Args:
        curr_node: Node exposing ``tree_level`` and ``children``.
        tree_level: Level to assign to ``curr_node``; each child receives
            ``tree_level + 1``.

    A node reachable through several paths is visited once per path and keeps
    the level from the last visit.
    """
    curr_node.tree_level = tree_level
    child_level = tree_level + 1
    for descendant in curr_node.children:
        set_tree_level(descendant, child_level)

In [None]:
def build_radlex_tree(df_csv):
    """Turn the RadLex CSV table into linked ``OntologyNode`` objects.

    Args:
        df_csv: DataFrame with at least the columns ``"Class ID"``,
            ``"Preferred Label"``, ``"Parents"`` (``"|"``-separated class ids),
            ``"Synonyms"`` and
            ``"http://radlex.org/RID/Preferred_Name_for_Obsolete"``.

    Returns:
        ``(node_list, root_node)``: one node per row in row order, and the
        last node that had a parent id not present in the table (``None`` if
        every parent id resolved).
    """
    # Build a RadLex node list, plus an id -> node lookup so that parent
    # resolution is a dict access instead of a scan over the whole DataFrame.
    node_list = []
    nodes_by_id = {}
    for idx, row in tqdm(df_csv.iterrows(), total=df_csv.shape[0], desc="Building RadLex tree"):
        ontology_node = OntologyNode(row_idx=idx, class_id=row["Class ID"], class_name=row["Preferred Label"], df_row=row)
        # When the label is just (part of) the class id, use the obsolete-name column instead.
        if row["Preferred Label"] in row["Class ID"]:
            ontology_node.class_name = row["http://radlex.org/RID/Preferred_Name_for_Obsolete"]
        node_list.append(ontology_node)
        # Should a class id appear on several rows, the first row wins.
        nodes_by_id.setdefault(ontology_node.class_id, ontology_node)

    # Resolve the node list and build a RadLex tree
    root_node = None
    for node in tqdm(node_list, total=len(node_list), desc="Building RadLex tree"):
        for parent_id in node.df_row["Parents"].split("|"):
            parent_node = nodes_by_id.get(parent_id)
            if parent_node is not None:
                node.add_parent(parent_node)
                parent_node.add_child(node)
            else:
                # In radlex, http://radlex.org/RID/RID0 has parent http://www.w3.org/2002/07/owl#Thing.
                # However, the RID0 is already the root node in the RadLex ontology. We can safely ignore the owl#Thing.
                # NOTE: any node with an unresolvable parent id ends up here and is marked as root.
                root_node = node
                node.is_root = True
                node.tree_level = 0

    return node_list, root_node

In [None]:
radlex_csv_path = "/home/yuxiang/liao/resources/bioportal/radlex/RADLEX.csv"
# keep_default_na=False keeps empty cells as "" (OntologyNode compares "Synonyms" against "").
df_radlex_csv = pd.read_csv(radlex_csv_path, keep_default_na=False)
radlex_nodes, radlex_root_node = build_radlex_tree(df_radlex_csv)
# Lookup table: class_id -> node.
radlex_nodes_dict = {node.class_id: node for node in radlex_nodes}
print(f"Number of RadLex nodes: {len(radlex_nodes)}")

# Trace all parents of every node. The property is read only for its side effect
# of filling each node's cached ancestor list.
for node in radlex_nodes:
    node.all_parents

# Assign tree levels starting from the root at level 0.
set_tree_level(radlex_root_node, tree_level=0)

## Analyse fuzzy

In [None]:
radlex_file = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/radlex_annotate/radlex_ann.json"
# Sanity check: print the first two raw lines of the RadLex annotation file.
with open(radlex_file, "r") as f:
    print(next(f))
    print(next(f))

In [None]:
# (radlex_id, radlex_name) -> set of surface texts that were fuzzy-matched to it.
fuzzy_match_dict = defaultdict(set)
# (radlex_id, radlex_name) -> number of such fuzzy matches.
fuzzy_match_count = Counter()

with open(radlex_file, "r") as f:
    for line_idx, line in enumerate(tqdm(f)):
        doc = json.loads(line.strip())
        data_split, row_idx, section_name, orig_sent_idx, split_sent_idx = doc["doc_key"].split("#")
        doc_key = f"{data_split}#{row_idx}#{section_name}"

        # Group the matches of this sentence by token span ("start_end").
        position_matches = defaultdict(list)
        for matched_info in doc["radlex"]:
            # matched_info = {"match_type": "fuzzy_lemma", "radlex_id": "http://radlex.org/RID/RID5978", "radlex_name": "parenchyma", "matched_text": "parenchymal", "char_indices": [20, 31], "tok_indices": [2, 3]}
            posi_id = "_".join(map(str, matched_info["tok_indices"]))
            position_matches[posi_id].append(matched_info)

        for matched_info in doc["radlex"]:
            posi_id = "_".join(map(str, matched_info["tok_indices"]))
            # Upstream matching logic: id = radlex_id + start + end; a fuzzy match is ignored when an
            # exact match exists for the same id. Matches with different ids are not considered there.
            # For example, "hemithorax" exact-matches hemithorax but can also fuzzy-match hemothorax.
            # Here we only analyse spans whose matches are ALL fuzzy matches.
            if matched_info["match_type"] == "fuzzy_lemma" and all([i["match_type"] == "fuzzy_lemma" for i in position_matches[posi_id]]):
                fuzzy_match_dict[(matched_info["radlex_id"], matched_info["radlex_name"])].add(matched_info["matched_text"])
                fuzzy_match_count.update([(matched_info["radlex_id"], matched_info["radlex_name"])])

In [None]:
# Number of distinct (radlex_id, radlex_name) pairs that were matched only fuzzily.
len(fuzzy_match_count)

In [None]:
# Print every fuzzy-only concept, most frequent first: its id, then name + count, then the matched texts.
for k, v in fuzzy_match_count.most_common():
    print(k[0])
    print("  ", k[1], v)
    print("  ", ", ".join(fuzzy_match_dict[k]))

## Process

In [None]:
# Based on the "Analyse fuzzy" results (507 radlex ids), these radlex_ids were picked by hand
# because they are unsuitable as fuzzy matches and occur frequently.

invalid_radlex_ids = set(
    [
        "http://radlex.org/RID/RID38667",  # thinning
        "http://radlex.org/RID/RID5022",  # stricture
        "http://radlex.org/RID/RID9889",  # frontalis
        "http://radlex.org/RID/RID3829",  # scar
        "http://radlex.org/RID/RID5801",  # lobular
        "http://radlex.org/RID/RID5015",  # inspissation
        "http://radlex.org/RID/RID5956",  # contents
        "http://radlex.org/RID/RID5783",  # contracted
        "http://radlex.org/RID/RID5843",  # inverted
        "http://radlex.org/RID/RID28656",  # secretin
        "http://radlex.org/RID/RID35977",  # property
        "http://radlex.org/RID/RID10453",  # standing position
        "http://radlex.org/RID/RID43613",  # Clements view
        "http://radlex.org/RID/RID2198",  # unciform
        "http://radlex.org/RID/RID49605",  # training
        "http://radlex.org/RID/RID29980",  # left hemithorax
        "http://radlex.org/RID/RID29979",  # right hemithorax
        "http://radlex.org/RID/RID29981",  # upper hemithorax
        "http://radlex.org/RID/RID29986",  # left lower hemithorax
        "http://radlex.org/RID/RID29982",  # right upper hemithorax
        "http://radlex.org/RID/RID29985",  # right lower hemithorax
        "http://radlex.org/RID/RID29983",  # left upper hemithorax
    ]
)

# radlex_id -> one specific matched text that must be rejected for that id (other texts stay valid).
invalid_radlex_text_pairs = {
    "http://radlex.org/RID/RID29984": "lower hemothorax",  # it also has "lower hemithoraxes", which is a correct fuzzy match to "lower hemithorax"
}

# For fuzzy matches between hemithorax and hemothorax we cannot tell which one the report really
# means, so we assume the wording in the report is correct.
# Because left/right hemithorax have exact matches, the only case that occurs is
# "hemothorax + location" being fuzzy-matched to a hemithorax concept.
# When such a fuzzy match occurs we simply ignore it, because we assume the "hemothorax" written
# in the report is correct.
# (Even if it is not, we do not want such data to pollute our dataset.)

In [None]:
# Same annotation file as in the "Analyse fuzzy" section; set again so this section can run on its own.
radlex_file = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/radlex_annotate/radlex_ann.json"
# Sanity check: print the first two raw lines.
with open(radlex_file, "r") as f:
    print(next(f))
    print(next(f))

In [None]:
# Count how often each radlex_id occurs under each match type.
# Later on, when one span matches several radlex_ids under the same type, the radlex_id with the
# highest frequency is chosen.
radlex_freq_dict = {"text": Counter(), "lower_text": Counter(), "lemma": Counter(), "fuzzy_lemma": Counter()}

with open(radlex_file, "r") as f:
    for line_idx, line in enumerate(tqdm(f)):
        doc = json.loads(line.strip())
        # The unpacked key parts are not used below; the unpacking still fails on a key without five parts.
        data_split, row_idx, section_name, orig_sent_idx, split_sent_idx = doc["doc_key"].split("#")
        doc_key = f"{data_split}#{row_idx}#{section_name}"

        for matched_info in doc["radlex"]:
            # matched_info = {"match_type": "fuzzy_lemma", "radlex_id": "http://radlex.org/RID/RID5978", "radlex_name": "parenchyma", "matched_text": "parenchymal", "char_indices": [20, 31], "tok_indices": [2, 3]}
            radlex_freq_dict[matched_info["match_type"]].update([matched_info["radlex_id"]])

In [None]:
# Reload the "temp_v2" checkpoint so this section can start from disk.
temp_path = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/combined_results/temp_v2"
new_ds = Dataset.load_from_disk(temp_path)
new_ds

In [None]:
# Rebuild doc_map for the RadLex annotation file: report-level doc_key -> list of record locations.
doc_map = defaultdict(list)

with open(radlex_file, "r") as f:
    for line_idx, line in enumerate(tqdm(f)):
        doc = json.loads(line.strip())
        # The record key has five "#"-separated parts; the first three form the report-level doc_key.
        data_split, row_idx, section_name, orig_sent_idx, split_sent_idx = doc["doc_key"].split("#")
        doc_key = f"{data_split}#{row_idx}#{section_name}"

        doc_map[doc_key].append({"doc_key": doc_key, "sent_idx": int(orig_sent_idx), "split_sent_idx": int(split_sent_idx), "file_path": radlex_file, "line_idx": line_idx})

In [None]:
def filter_by_priority(element):
    """Attach at most one RadLex match per token span to every split sentence of a row.

    Adds ``element["radlex"]``, a list parallel to ``element["split_sents"]``.
    For each token span of a sentence, the match types are tried in priority
    order ``text`` > ``lower_text`` > ``lemma`` > ``fuzzy_lemma``. The first
    type with a surviving candidate wins. Fuzzy candidates are discarded when
    their id is in ``invalid_radlex_ids`` or their (id, matched text) pair is in
    ``invalid_radlex_text_pairs``. Among several candidates of the winning
    type, the one whose radlex_id is most frequent in ``radlex_freq_dict`` is
    kept (the first candidate on ties).

    Args:
        element: One dataset row; must provide ``doc_key``, ``split_sents`` and
            ``sent_idx_split_idx``.

    Returns:
        The same ``element`` with the ``radlex`` column added.

    Raises:
        AssertionError: If a record does not line up with the row.
        ValueError: If a record's (sent_idx, split_idx) pair is not present in
            ``sent_idx_split_idx``.
    """
    element["radlex"] = [[] for _ in range(len(element["split_sents"]))]
    if len(element["split_sents"]) == 0:
        return element

    sorted_doc_info_list = sorted(doc_map[element["doc_key"]], key=lambda x: (x["sent_idx"], x["split_sent_idx"]))
    for info_dict in sorted_doc_info_list:
        # file_doc = {"doc_key": "train#0#impression#0#1", "sent_text": "Decreased bibasilar parenchymal opacities are now minimal.", "radlex": [{"match_type": "lemma", "radlex_id": "http://radlex.org/RID/RID5733", "radlex_name": "decreasing", "matched_text": "Decreased", "char_indices": [0, 9], "tok_indices": [0, 1]}, ...]}
        file_doc = load_jsonline_from_file(info_dict["file_path"], info_dict["line_idx"])
        data_split, row_idx, section_name, orig_sent_idx, split_sent_idx = file_doc["doc_key"].split("#")
        orig_sent_idx = int(orig_sent_idx)
        split_sent_idx = int(split_sent_idx)
        assert info_dict["sent_idx"] == orig_sent_idx and info_dict["split_sent_idx"] == split_sent_idx
        _doc_key = f"{data_split}#{row_idx}#{section_name}"
        assert element["doc_key"] == _doc_key
        _idx = element["sent_idx_split_idx"].index([orig_sent_idx, split_sent_idx])
        assert element["split_sents"][_idx] == file_doc["sent_text"]

        # Group this sentence's matches by token span (start, end).
        position_matches = defaultdict(list)
        for matched_info in file_doc["radlex"]:
            # matched_info = {"match_type": "fuzzy_lemma", "radlex_id": "http://radlex.org/RID/RID5978", "radlex_name": "parenchyma", "matched_text": "parenchymal", "char_indices": [20, 31], "tok_indices": [2, 3]}
            posi_id = (matched_info["tok_indices"][0], matched_info["tok_indices"][1])
            position_matches[posi_id].append(matched_info)

        sorted_position_matches = sorted(position_matches.items(), key=lambda x: x[0])
        for _, span_matches in sorted_position_matches:
            # Walk the match types in priority order. The first type that yields a match is added
            # to radlex, then we break and continue with the next span (position).
            for target_type in ["text", "lower_text", "lemma", "fuzzy_lemma"]:
                # One type can have several matches on the same span, e.g.:
                # [{'match_type': 'lower_text', 'radlex_id': 'http://radlex.org/RID/RID39433', 'radlex_name': 'arterial phase (liver)', 'matched_text': 'AP', 'char_indices': [0, 2], 'tok_indices': [0, 1]},
                # {'match_type': 'lower_text', 'radlex_id': 'http://radlex.org/RID/RID11080', 'radlex_name': 'arterial phase', 'matched_text': 'AP', 'char_indices': [0, 2], 'tok_indices': [0, 1]}]
                filtered_matches = []
                for matched_span in span_matches:
                    if matched_span["match_type"] != target_type:
                        continue
                    # Drop the invalid radlex_ids (only fuzzy matches are subject to this filter).
                    if target_type == "fuzzy_lemma":
                        if matched_span["radlex_id"] in invalid_radlex_ids:
                            continue
                        if matched_span["radlex_id"] in invalid_radlex_text_pairs and matched_span["matched_text"] == invalid_radlex_text_pairs[matched_span["radlex_id"]]:
                            continue
                    filtered_matches.append(matched_span)

                # A span gets at most one radlex_id.
                # With several candidates, pick the radlex_id that is most frequent in the dataset.
                if filtered_matches:
                    type_freqs = radlex_freq_dict[target_type]
                    # An id missing from the frequency table counts as 0 instead of yielding None,
                    # which max() could not compare. max() returns the first maximal candidate.
                    best_match = max(filtered_matches, key=lambda match: type_freqs.get(match["radlex_id"], 0))
                    element["radlex"][_idx].append(best_match)
                    break

    return element


# Debugging aid: uncomment to work on the first 10 rows only.
# temp_ds = new_ds.select(range(10))
# Run filter_by_priority over the dataset and keep the result.
new_ds = new_ds.map(filter_by_priority)

In [None]:
# Spot-check: the split sentences of one row.
new_ds[37]["split_sents"]

In [None]:
# Spot-check: the RadLex matches kept for the second split sentence of that row.
new_ds[37]["radlex"][1]

In [None]:
# Checkpoint as "temp_v1". This reuses the name of the earlier LLM-stage checkpoint; the dataset
# in memory was loaded from "temp_v2".
# NOTE(review): versions presumably alternate so a dataset is never saved onto the directory it
# was loaded from — confirm.
temp_path = save_to_temp(new_ds, version=1)
temp_path

# Load cxrgraph results

In [None]:
# Reload the "temp_v1" checkpoint so this section can start from disk.
temp_path = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/combined_results/temp_v1"
new_ds = Dataset.load_from_disk(temp_path)
new_ds

In [None]:
cxrgraph_file = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/cxrgraph/inference.json"

# Sanity check: print the first five raw lines of the CXRGraph inference file.
with open(cxrgraph_file, "r") as f:
    print(next(f))
    print(next(f))
    print(next(f))
    print(next(f))
    print(next(f))

In [None]:
# Rebuild doc_map for the CXRGraph inference file: report-level doc_key -> list of record locations.
doc_map = defaultdict(list)

with open(cxrgraph_file, "r") as f:
    for line_idx, line in enumerate(tqdm(f)):
        doc = json.loads(line.strip())
        # The record key has five "#"-separated parts; the first three form the report-level doc_key.
        data_split, row_idx, section_name, orig_sent_idx, split_sent_idx = doc["doc_key"].split("#")
        doc_key = f"{data_split}#{row_idx}#{section_name}"

        doc_map[doc_key].append({"doc_key": doc_key, "sent_idx": int(orig_sent_idx), "split_sent_idx": int(split_sent_idx), "file_path": cxrgraph_file, "line_idx": line_idx})

In [None]:
def is_all_number_or_symbols(lst):
    """Tell whether every character of every string in ``lst`` is a digit or a listed symbol.

    Args:
        lst: Iterable of strings (e.g. the tokens of one span).

    Returns:
        ``False`` as soon as one character is neither a digit (``str.isdigit``)
        nor one of the punctuation characters listed below; ``True`` otherwise,
        which includes an empty ``lst`` and empty strings.
    """
    symbol_chars = "!@#$%^&*()-_=+[]{};:'\",.<>?/|\\`~"
    for item in lst:
        for char in item:
            if not (char in symbol_chars or char.isdigit()):
                return False
    return True


def update_dataset(element):
    """Attach CXRGraph entity, attribute and relation predictions to one report row.

    Adds ``cxrgraph_ent``, ``cxrgraph_attr`` and ``cxrgraph_rel``, each a list
    parallel to ``element["split_sents"]``. Spans made up only of digits and
    symbols are skipped, attributes whose three values are all ``"NA"`` are
    skipped, and a relation is kept only when both of its spans pass the
    digit/symbol check. Stored token indices are ``[start, end)``: the file's
    inclusive end index plus one.

    Args:
        element: One dataset row; must provide ``doc_key``, ``split_sents``,
            ``sent_idx_split_idx`` and ``split_sent_toks``.

    Returns:
        The same ``element`` with the three columns added.

    Raises:
        AssertionError: If a record does not line up with the row, or does not
            contain exactly one sentence.
        ValueError: If a record's (sent_idx, split_idx) pair is not present in
            ``sent_idx_split_idx``.
    """
    num_split_sents = len(element["split_sents"])
    element["cxrgraph_ent"] = [[] for _ in range(num_split_sents)]
    element["cxrgraph_attr"] = [[] for _ in range(num_split_sents)]
    element["cxrgraph_rel"] = [[] for _ in range(num_split_sents)]
    if num_split_sents == 0:
        return element

    sorted_doc_info_list = sorted(doc_map[element["doc_key"]], key=lambda x: (x["sent_idx"], x["split_sent_idx"]))
    for info_dict in sorted_doc_info_list:
        # file_doc = {"doc_key": "train#0#impression#0#1", "sentences": [["Decreased", "bibasilar", "parenchymal", "opacities", "are", "now", "minimal", "."]],
        # "pred_ner": [[[0, 0, "Observation-Present"], [1, 1, "Anatomy"], [2, 2, "Anatomy"], [3, 3, "Observation-Present"], [6, 6, "Observation-Present"], [7, 7, "Observation-Present"]]],
        # "pred_attr": [[[0, 0, "NA", "NA", "NA"]]],
        # "pred_rel": [[[0, 0, 3, 3, "modify"], [2, 2, 1, 1, "part_of"], [3, 3, 2, 2, "located_at"], [6, 6, 7, 7, "modify"], [6, 6, 3, 3, "modify"], [7, 7, 3, 3, "modify"]]]}
        file_doc = load_jsonline_from_file(info_dict["file_path"], info_dict["line_idx"])
        data_split, row_idx, section_name, orig_sent_idx, split_sent_idx = file_doc["doc_key"].split("#")
        orig_sent_idx = int(orig_sent_idx)
        split_sent_idx = int(split_sent_idx)
        assert info_dict["sent_idx"] == orig_sent_idx and info_dict["split_sent_idx"] == split_sent_idx
        _doc_key = f"{data_split}#{row_idx}#{section_name}"
        assert element["doc_key"] == _doc_key
        _idx = element["sent_idx_split_idx"].index([orig_sent_idx, split_sent_idx])

        # Validate the record shape BEFORE indexing into it, then check once that
        # the record's tokens are the tokens stored on the row.
        assert len(file_doc["sentences"]) == 1
        sent_toks = file_doc["sentences"][0]
        assert sent_toks == element["split_sent_toks"][_idx]

        for ner in file_doc["pred_ner"][0]:
            # ner = [0, 0, "Observation-Present"]
            tok_start = ner[0]
            tok_end = ner[1] + 1
            ent_toks = sent_toks[tok_start:tok_end]
            if not is_all_number_or_symbols(ent_toks):
                ent_type = ner[2]
                element["cxrgraph_ent"][_idx].append({"tok_indices": [tok_start, tok_end], "ent_toks": ent_toks, "ent_type": ent_type})

        for attr in file_doc["pred_attr"][0]:
            # attr = [0, 0, "NA", "NA", "NA"]
            tok_start = attr[0]
            tok_end = attr[1] + 1
            ent_toks = sent_toks[tok_start:tok_end]
            if not is_all_number_or_symbols(ent_toks):
                attr_normality = attr[2]
                attr_action = attr[3]
                attr_change = attr[4]
                # Keep the attribute only when at least one of its values is informative.
                if attr_normality != "NA" or attr_action != "NA" or attr_change != "NA":
                    element["cxrgraph_attr"][_idx].append({"tok_indices": [tok_start, tok_end], "ent_toks": ent_toks, "attr_normality": attr_normality, "attr_action": attr_action, "attr_change": attr_change})

        for rel in file_doc["pred_rel"][0]:
            # rel = [0, 0, 3, 3, "modify"]
            subj_tok_start = rel[0]
            subj_tok_end = rel[1] + 1
            subj_toks = sent_toks[subj_tok_start:subj_tok_end]
            obj_tok_start = rel[2]
            obj_tok_end = rel[3] + 1
            obj_toks = sent_toks[obj_tok_start:obj_tok_end]
            if not is_all_number_or_symbols(subj_toks) and not is_all_number_or_symbols(obj_toks):
                rel_type = rel[4]
                element["cxrgraph_rel"][_idx].append({"subj_tok_indices": [subj_tok_start, subj_tok_end], "subj_toks": subj_toks, "obj_tok_indices": [obj_tok_start, obj_tok_end], "obj_toks": obj_toks, "rel_type": rel_type})

    return element


# Debugging aid: uncomment to work on the first 100 rows only.
# temp_ds = new_ds.select(range(100))
# Run update_dataset over the dataset and keep the result.
new_ds = new_ds.map(update_dataset)

In [None]:
# Checkpoint as "temp_v2" (the dataset in memory was loaded from "temp_v1"); display the path.
temp_path = save_to_temp(new_ds, version=2)
temp_path

# Load radcoref results

In [None]:
# Reload the "temp_v2" checkpoint so this section can start from disk.
temp_path = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/combined_results/temp_v2"
new_ds = Dataset.load_from_disk(temp_path)
new_ds

In [None]:
# NOTE(review): the variable keeps the name cxrgraph_file, but from here on it points at the
# radcoref output file.
cxrgraph_file = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/radcoref/coref_inference.json"

# Print the first and the sixth line of the file; the four lines in between are read and discarded.
with open(cxrgraph_file, "r") as f:
    print(next(f))
    next(f)
    next(f)
    next(f)
    next(f)
    print(next(f))

In [None]:
# Rebuild doc_map for the radcoref file. Unlike the earlier sections this is a plain dict holding
# ONE location per doc_key: the record's doc_key is used as-is, and a later line with the same key
# overwrites the earlier one.
doc_map = {}

with open(cxrgraph_file, "r") as f:
    for line_idx, line in enumerate(tqdm(f)):
        doc = json.loads(line.strip())
        doc_key = doc["doc_key"]

        doc_map[doc_key] = {"doc_key": doc_key, "file_path": cxrgraph_file, "line_idx": line_idx}

In [None]:
def update_dataset(element):
    """Attach the coreference clusters of one report row as ``element["radcoref"]``.

    Only clusters with more than one mention are kept. Each kept mention is
    stored as ``{"tok_indices": [start, end), "mention_toks": [...],
    "target_sent_idx": i}``. The token indices refer to the concatenation of
    all ``split_sent_toks`` of the row, and ``i`` is the index of the split
    sentence that contains the mention's first token.

    Args:
        element: One dataset row; must provide ``doc_key``, ``split_sents`` and
            ``split_sent_toks``.

    Returns:
        The same ``element`` with the ``radcoref`` column added (an empty list
        when the row has no split sentences).

    Raises:
        AssertionError: If the record's key, tokens or mention text do not
            match the row.
        KeyError: If the row has split sentences but no entry in ``doc_map``.
    """
    element["radcoref"] = []
    if len(element["split_sents"]) == 0:
        return element

    # file_doc = {"doc_key": "train#0#impression", "split_sent_toks": [["12/15/2004", "at", "1830", "accession", ":", "6772908", ":"], ["Stable", "right", "pigtail", "pleural", "catheter", "."], ["ETT", "is", "stable", "."], ...],
    # "coref_clusters": [[[0, 0, "12/15/2004"]], [[13, 13, "ETT"], [143, 144, "Stable ETT"]], [[17, 19, "The NG tube"], [146, 147, "NG tube"]], [[30, 33, "Right upper lobe opacity"], [35, 36, "The opacity"], [92, 95, "the right upper lobe"], [116, 116, "Opacity"], [123, 124, "The opacity"], [162, 165, "the right upper lobe"], [167, 168, "The opacity"]], [[49, 55, "Fracture of the posterior right rib ."], [49, 54, "Fracture of the posterior right rib"], [56, 57, "The fracture"]], [[76, 76, "CT"], [88, 88, "CT"]]]}
    location = doc_map[element["doc_key"]]
    file_doc = load_jsonline_from_file(location["file_path"], location["line_idx"])
    assert element["doc_key"] == file_doc["doc_key"]

    # Note: when the LLM sentences were processed with spaCy, unqualified sentences were filtered
    # out manually, so the split_sent_toks of a filtered sentence is an empty list.
    for sent_pos, file_toks in enumerate(file_doc["split_sent_toks"]):
        if file_toks != []:
            assert element["split_sent_toks"][sent_pos] == file_toks

    # Flatten the row's tokens and remember which split sentence each token came from.
    flat_toks = []
    tok_to_sent = []
    for sent_pos, sent_toks in enumerate(element["split_sent_toks"]):
        for tok in sent_toks:
            flat_toks.append(tok)
            tok_to_sent.append(sent_pos)

    for cluster in file_doc["coref_clusters"]:
        # Singleton clusters carry no coreference link.
        if len(cluster) <= 1:
            continue
        kept_mentions = []
        for mention in cluster:
            # mention = [0, 0, "12/15/2004"]
            first_tok = mention[0]
            end_tok = mention[1] + 1

            mention_toks = flat_toks[first_tok:end_tok]
            assert mention[2] == " ".join(mention_toks)

            kept_mentions.append({"tok_indices": [first_tok, end_tok], "mention_toks": mention_toks, "target_sent_idx": tok_to_sent[first_tok]})
        element["radcoref"].append(kept_mentions)

    return element


# Debugging aid: uncomment to work on the first 100 rows only.
# temp_ds = new_ds.select(range(100))
# Run update_dataset over the dataset and keep the result.
new_ds = new_ds.map(update_dataset)

In [None]:
# Checkpoint as "temp_v1" (the dataset in memory was loaded from "temp_v2"); display the path.
temp_path = save_to_temp(new_ds, version=1)
temp_path

# Final process

In [4]:
# Reload the "temp_v1" checkpoint, which at this point holds the output of the radcoref section.
temp_path = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/combined_results/temp_v1"
new_ds = Dataset.load_from_disk(temp_path)
new_ds

Dataset({
    features: ['doc_key', 'sent_toks', 'tok_char_indices', 'sents', 'sent_char_indices', 'split_sents', 'sent_idx_split_idx', 'split_sent_toks', 'split_tok_char_indices', 'radlex', 'cxrgraph_ent', 'cxrgraph_attr', 'cxrgraph_rel', 'radcoref'],
    num_rows: 1136366
})

In [None]:
# Find the first row with at least one empty token list in split_sent_toks, print it and stop.
for data in tqdm(new_ds):
    if any([len(sent) == 0 for sent in data["split_sent_toks"]]):
        print(data["doc_key"])
        print(data["split_sents"])
        print(data["split_sent_toks"])
        break

# This confirms that invalid sentences in split_sents are skipped, which leaves an empty list at
# the corresponding index of split_sent_toks.
# No handling is needed for this case, because every mention in radcoref explicitly records the
# index of its split sentence.

In [None]:
# Write the final combined dataset. Dataset.save_to_disk() returns None, so its result must not
# be assigned back to new_ds; otherwise the dataset in memory is replaced by None.
new_ds.save_to_disk("/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_sents/combined_results/combine_final")

Saving the dataset (13/13 shards): 100%|██████████| 1136366/1136366 [00:16<00:00, 67942.11 examples/s]
