上一个方法直接使用 graph repr，并不能让 model 有效地关注到图像中的重点。原因可能是 graph 太复杂、杂乱的信息太多——我们把所有的 ent 和 rel 都放了进去。

这里尝试提前筛选出最重要的部分，避免在 graph 中保留过多信息。

# Get root nodes from CXRGraph

In [1]:
import bisect
from collections import Counter, defaultdict

from datasets import load_from_disk
from tqdm import tqdm
import os

import spacy

nlp = spacy.load("en_core_web_sm")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Entity:
    """A labelled token span (one graph node) inside a single sentence.

    Attributes:
        id: identifier assigned later by the caller (``None`` until then).
        tok_indices: ``[start, end]`` token offsets of the span.
        label: the raw entity label (e.g. ``"Observation-Present"``, ``"Anatomy"``).
        label_type: coarse type derived from ``label``: ``"OBS"`` if the label
            contains ``"Observation"``, ``"ANAT"`` if it equals ``"Anatomy"``,
            otherwise ``"LOCATT"``.
        sent_id: index of the sentence the span belongs to.
        tok_list / tok_str: the span's tokens as a list / as one string.
        attr_normal, attr_action, attr_change: attribute values, ``"NA"`` by default.
        chain_info: per relation type, the entities this entity points ``"to"``
            and the entities pointing at it (``"from"``).

    Equality and hashing are based on ``tok_indices`` only. Comparing with a
    non-Entity delegates to ``other == self.tok_indices``, which is what makes
    ``ent_list.index([start, end])`` look an entity up by its span.
    """

    def __init__(self, start, end, label, sent_id, tok_list=None, tok_str=None):
        self.id = None
        self.tok_indices = [start, end]
        self.label = label

        self.sent_id = sent_id
        if tok_list:
            self.tok_list = tok_list
            # An explicit tok_str wins over the joined tokens.
            self.tok_str = tok_str if tok_str else " ".join(tok_list)
        elif tok_str:
            self.tok_str = tok_str
            self.tok_list = tok_str.split(" ")
        else:
            # Neither representation was supplied: still define both attributes
            # so that __repr__ and later consumers never hit an AttributeError.
            self.tok_list = []
            self.tok_str = ""

        if "Observation" in label:
            self.label_type = "OBS"
        elif "Anatomy" == label:
            self.label_type = "ANAT"
        else:
            self.label_type = "LOCATT"

        self.attr_normal = "NA"
        self.attr_action = "NA"
        self.attr_change = "NA"

        self.chain_info = {
            "modify": {"from": [], "to": []},
            "part_of": {"from": [], "to": []},
            "located_at": {"from": [], "to": []},
            "suggestive_of": {"from": [], "to": []},
        }

    def __repr__(self) -> str:
        return f"{self.tok_str}"

    def __str__(self) -> str:
        return self.__repr__()

    def __eq__(self, other):
        if isinstance(other, Entity):
            return self.tok_indices == other.tok_indices
        else:
            # Allows comparison against a raw [start, end] span.
            return other == self.tok_indices

    def __hash__(self):
        # Must stay consistent with __eq__: equal spans hash equally.
        return hash(str(self.tok_indices))


class Relation:
    """A directed, labelled edge from a subject entity to an object entity."""

    def __init__(self, subj_ent, obj_ent, label):
        self.subj_ent, self.obj_ent = subj_ent, obj_ent
        self.label = label

    def __repr__(self) -> str:
        # Rendered as "<subject text> <label> <object text>".
        return "{} {} {}".format(self.subj_ent.tok_str, self.label, self.obj_ent.tok_str)

    def __str__(self) -> str:
        return repr(self)


class LinkedGraph:
    """A group of linked entities from one sentence plus the relations among them.

    Attributes:
        id: identifier, ``None`` until a caller assigns one.
        ents: the entities, sorted by the start of their token span.
        rels: relations whose subject and object are both in ``ents``
            (filled by :meth:`get_involved_rels`).
        sent_id: sentence id shared by all entities.
    """

    def __init__(self, ents):
        self.id = None
        self.ents = sorted(ents, key=lambda x: x.tok_indices[0])
        self.rels = []
        self.sent_id = ents[0].sent_id

        # All entities of one graph must come from the same sentence.
        assert len({ent.sent_id for ent in ents}) == 1

    def get_involved_rels(self, rel_list):
        """Store (and return) the relations of ``rel_list`` that lie inside this graph.

        A relation is kept when both its ``subj_ent`` and ``obj_ent`` are
        members of ``self.ents``; input order is preserved.
        """
        self.rels = [rel for rel in rel_list if rel.subj_ent in self.ents and rel.obj_ent in self.ents]
        return self.rels

    def __repr__(self) -> str:
        return f"{[i.tok_str for i in self.ents]}"

    def __str__(self) -> str:
        return self.__repr__()

In [3]:
def search_linked_ents(curr_ent, visited, group):
    """Depth-first collect every entity reachable from ``curr_ent``.

    Follows all relation types in ``chain_info`` in both directions.
    ``visited`` (a set) and ``group`` (a list, in visiting order) are
    updated in place; nothing is returned.
    """
    visited.add(curr_ent)
    group.append(curr_ent)

    # Gather every adjacent entity first, then descend into the unseen ones.
    adjacent = []
    for directions in curr_ent.chain_info.values():
        for linked in directions.values():
            adjacent.extend(linked)

    for candidate in adjacent:
        if candidate in visited:
            continue
        search_linked_ents(candidate, visited, group)


def max_coverage_spans(spans):
    """Pick non-overlapping spans whose total length is maximal.

    Classic weighted interval scheduling, with ``end - start`` as the weight.
    Two spans are compatible when one ends at or before the other starts.

    Args:
        spans: sequence of ``(start, end)`` pairs (lists or tuples).

    Returns:
        A tuple ``(selected_indices, selected_spans, total_coverage)`` where
        ``selected_indices`` are positions in the *input* ``spans`` (so
        ``spans[selected_indices[k]] == selected_spans[k]``), listed by
        ascending span end, and ``total_coverage`` is the summed length.
        Empty input yields ``([], [], 0)``.
    """
    if not spans:
        return [], [], 0

    # Sort by end, but remember where each span came from: callers use the
    # returned indices to look up the original (unsorted) list.
    order = sorted(range(len(spans)), key=lambda idx: spans[idx][1])
    sorted_spans = [spans[idx] for idx in order]
    n = len(sorted_spans)
    starts = [s[0] for s in sorted_spans]
    ends = [s[1] for s in sorted_spans]
    lengths = [e - s for s, e in sorted_spans]

    # j_values[i] is the largest j with ends[j] <= starts[i], or -1 if none:
    # bisect_right finds the first end strictly greater than starts[i].
    j_values = [bisect.bisect_right(ends, start_i) - 1 for start_i in starts]

    # dp[i] = best total coverage using only the first i+1 spans, i.e. the
    # better of "skip span i" (dp[i-1]) and "take span i plus the best
    # solution compatible with it".
    dp = [0] * n
    dp[0] = lengths[0]
    for i in range(1, n):
        j = j_values[i]
        current = lengths[i] + (dp[j] if j >= 0 else 0)
        dp[i] = max(dp[i - 1], current)

    # Backtrack: span i was taken exactly when it improved on dp[i-1]; after
    # taking it, jump to the last span compatible with it.
    selected_positions = []
    i = n - 1
    while i >= 0:
        if i == 0:
            if dp[i] == lengths[i]:
                selected_positions.append(i)
            break
        if dp[i] > dp[i - 1]:
            selected_positions.append(i)
            i = j_values[i]
        else:
            i -= 1

    selected_positions.reverse()
    selected_indices = [order[pos] for pos in selected_positions]
    selected_spans = [sorted_spans[pos] for pos in selected_positions]
    total_coverage = dp[-1]

    return selected_indices, selected_spans, total_coverage


def resolve_ent_rel(split_sent_idx, cxrgraph_ent_lst, cxrgraph_rel_lst, cxrgraph_attr_lst, radlex_lst):
    """Build the entities and relations of one sentence and normalise them with RadLex.

    Entities and relations come from the CXRGraph output; the RadLex matches
    are only used to merge/rename CXRGraph entities.

    Args:
        split_sent_idx: sentence index, stored as ``sent_id`` on every entity.
        cxrgraph_ent_lst: dicts with keys ``"tok_indices"``, ``"ent_type"``, ``"ent_toks"``.
        cxrgraph_rel_lst: dicts with keys ``"subj_tok_indices"``, ``"obj_tok_indices"``, ``"rel_type"``.
        cxrgraph_attr_lst: dicts with keys ``"tok_indices"``, ``"attr_normality"``,
            ``"attr_action"``, ``"attr_change"``.
        radlex_lst: dicts with keys ``"tok_indices"``, ``"radlex_name"``, ``"radlex_id"``.

    Returns:
        ``(ent_list, rel_list)``: the (possibly merged) ``Entity`` objects and the
        ``Relation`` objects that connect them.

    Raises:
        ValueError: if an attribute or relation refers to a span for which no
            entity exists (raised by ``list.index``).
    """
    ent_list = []
    rel_list = []
    for ent in cxrgraph_ent_lst:
        ent = Entity(start=ent["tok_indices"][0], end=ent["tok_indices"][1], label=ent["ent_type"], tok_list=ent["ent_toks"], sent_id=split_sent_idx)
        ent_list.append(ent)
    for attr in cxrgraph_attr_lst:
        # Entity.__eq__ accepts a raw [start, end] span, so index() finds the entity by span.
        ent = ent_list[ent_list.index(attr["tok_indices"])]
        ent.attr_normal = attr["attr_normality"]
        ent.attr_action = attr["attr_action"]
        ent.attr_change = attr["attr_change"]
    for rel in cxrgraph_rel_lst:
        subj_ent = ent_list[ent_list.index(rel["subj_tok_indices"])]
        obj_ent = ent_list[ent_list.index(rel["obj_tok_indices"])]
        label = rel["rel_type"]
        if obj_ent not in subj_ent.chain_info[label]["to"]:
            subj_ent.chain_info[label]["to"].append(obj_ent)
        if subj_ent not in obj_ent.chain_info[label]["from"]:
            obj_ent.chain_info[label]["from"].append(subj_ent)
        rel_list.append(Relation(subj_ent, obj_ent, label))

    # Set ent id (in order of span start)
    for ent_idx, ent in enumerate(sorted(ent_list, key=lambda x: x.tok_indices[0])):
        ent.id = f"E{ent_idx}"

    # Choose the subset of RadLex matches with maximal coverage.
    radlex_ent_indices = [node["tok_indices"] for node in radlex_lst]
    selected_idx_list, _, _ = max_coverage_spans(radlex_ent_indices)

    # Replace CXRGraph entities with the RadLex entity that covers them.
    for radlex_idx in selected_idx_list:
        radlex_ent = radlex_lst[radlex_idx]
        merged_cxrgraph_ents = []
        for cxrgraph_ent in ent_list:
            # A CXRGraph entity lying inside (or equal to) the RadLex span becomes a
            # replacement candidate; a partial overlap means this RadLex match must be
            # skipped altogether.
            pos_ab = check_span_relation(cxrgraph_ent.tok_indices, radlex_ent["tok_indices"])
            if pos_ab in ["equal", "inside"]:
                merged_cxrgraph_ents.append(cxrgraph_ent)
            elif pos_ab == "overlap":
                # Discard candidates collected so far; otherwise the match would still be
                # applied to them and the new entity would overlap the remaining one.
                merged_cxrgraph_ents = []
                break

        # If there are candidates, replace them with one RadLex-named entity.
        if merged_cxrgraph_ents:
            inherited_label = get_label_inheritance(merged_cxrgraph_ents)
            inherited_attr_dict = get_attr_inheritance(merged_cxrgraph_ents)

            new_ent = Entity(start=radlex_ent["tok_indices"][0], end=radlex_ent["tok_indices"][1], label=inherited_label, tok_str=radlex_ent["radlex_name"], sent_id=split_sent_idx)
            new_ent.attr_normal = inherited_attr_dict["normality"]
            new_ent.attr_action = inherited_attr_dict["action"]
            new_ent.attr_change = inherited_attr_dict["change"]
            new_ent.id = radlex_ent["radlex_id"]

            # inherit chain info
            for cxrgraph_ent in merged_cxrgraph_ents:
                for rel_type, from_to_dict in cxrgraph_ent.chain_info.items():
                    # Inherit every from/to link of the merged entities, except links
                    # that point at another merged entity (internal links).
                    for key, value_lst in from_to_dict.items():
                        for value in value_lst:
                            if value not in merged_cxrgraph_ents:
                                if value not in new_ent.chain_info[rel_type][key]:
                                    new_ent.chain_info[rel_type][key].append(value)

            # replace from ent_list
            ent_list.append(new_ent)
            for cxrgraph_ent in merged_cxrgraph_ents:
                ent_list.remove(cxrgraph_ent)

            # replace from rel_list
            # e.g. pleural_effusion replaces both "pleural" and "effusion", which affects:
            #   opacifications suggestive_of effusions
            #   bilateral modify pleural
            #   effusions located_at pleural
            rel_objs_tobe_removed = []
            for rel in rel_list:
                if rel.subj_ent in merged_cxrgraph_ents and rel.obj_ent in merged_cxrgraph_ents:
                    # Internal relation: drop it. The from/to chains were already
                    # inherited by new_ent above, so nothing else to do here.
                    rel_objs_tobe_removed.append(rel)
                elif rel.subj_ent in merged_cxrgraph_ents:
                    # subj need to be replaced
                    if rel.subj_ent in rel.obj_ent.chain_info[rel.label]["from"]:
                        rel.obj_ent.chain_info[rel.label]["from"].remove(rel.subj_ent)
                    if new_ent not in rel.obj_ent.chain_info[rel.label]["from"]:
                        rel.obj_ent.chain_info[rel.label]["from"].append(new_ent)
                    rel.subj_ent = new_ent
                elif rel.obj_ent in merged_cxrgraph_ents:
                    # obj need to be replaced
                    if rel.obj_ent in rel.subj_ent.chain_info[rel.label]["to"]:
                        rel.subj_ent.chain_info[rel.label]["to"].remove(rel.obj_ent)
                    if new_ent not in rel.subj_ent.chain_info[rel.label]["to"]:
                        rel.subj_ent.chain_info[rel.label]["to"].append(new_ent)
                    rel.obj_ent = new_ent

            for rel in rel_objs_tobe_removed:
                rel_list.remove(rel)

    # NOTE(review): Relation defines neither __eq__ nor __hash__, so this only detects the
    # very same Relation object appearing twice, not two relations with equal content.
    assert len(rel_list) == len(set(rel_list)), f"{rel_list}"
    return ent_list, rel_list


def get_label_inheritance(cxrgraph_ents):
    """Return the label a merged entity inherits from ``cxrgraph_ents``.

    The first label of the priority list below that occurs among the
    entities wins; if none occurs the result is ``"Location-Attribute"``.
    """
    present_labels = [ent.label for ent in cxrgraph_ents]
    priority = ("Observation-Absent", "Observation-Uncertain", "Observation-Present", "Anatomy")
    for label in priority:
        if label in present_labels:
            return label
    return "Location-Attribute"


def get_attr_inheritance(cxrgraph_ents):
    """Return the attributes a merged entity inherits from ``cxrgraph_ents``.

    For each of ``"normality"``, ``"action"`` and ``"change"`` the first value
    of a fixed priority list that occurs among the entities is chosen;
    ``"NA"`` is used when none occurs.
    """

    def first_present(observed, priority):
        # Highest-priority value that was actually observed, else "NA".
        return next((name for name in priority if name in observed), "NA")

    normals = [ent.attr_normal for ent in cxrgraph_ents]
    actions = [ent.attr_action for ent in cxrgraph_ents]
    changes = [ent.attr_change for ent in cxrgraph_ents]
    # Sanity check: every change value must start with a capital letter.
    assert all([value[0].istitle() for value in changes]), f"{changes} {normals} {actions}"

    return {
        "normality": first_present(normals, ("Normal", "Abnormal")),
        "action": first_present(actions, ("Essential", "Removable")),
        "change": first_present(changes, ("Positive", "Negative", "Unchanged")),
    }


def check_span_relation(ent_a_indices, ent_b_indices):
    """Describe how span A lies relative to span B.

    Returns one of ``"before"``, ``"after"``, ``"equal"``, ``"contain"``
    (A covers B), ``"inside"`` (A is covered by B) or ``"overlap"``
    (partial intersection). Spans that merely touch count as before/after.
    """
    a_start, a_end = ent_a_indices[0], ent_a_indices[1]
    b_start, b_end = ent_b_indices[0], ent_b_indices[1]

    if a_end <= b_start:
        return "before"
    if a_start >= b_end:
        return "after"
    if a_start == b_start and a_end == b_end:
        return "equal"
    if a_start <= b_start and b_end <= a_end:
        return "contain"
    if b_start <= a_start and a_end <= b_end:
        return "inside"
    return "overlap"

In [4]:
class SentenceRepresentation:
    """Graph-derived representations of a single sentence.

    Attributes:
        doc_key, sent_id, sent_text: identification and raw text of the sentence.
        ent_tuples: ``(tok_str, label, attr_normal, attr_action, attr_change)`` per entity.
        rel_tuples: ``(subj_tok_str, label, obj_tok_str)`` per relation.
        normal: strings for graphs with an Observation-Present entity marked Normal.
        abnormal: strings for graphs that fall in none of the other categories.
        absent: strings for graphs containing an Observation-Absent entity.
        uncertain: strings for graphs containing an Observation-Uncertain entity
            or a ``suggestive_of`` relation.
    """

    def __init__(self, doc_key, sent_id, sent_text):
        self.doc_key = doc_key
        self.sent_id = sent_id
        self.sent_text = sent_text
        self.ent_tuples = []  # (tok_str, label, attr_normal, attr_action, attr_change)
        self.rel_tuples = []  # (subj_tok_str, label, obj_tok_str)
        self.normal = []  # Observation-Present and Normal
        self.abnormal = []  # Observation-Present and Abnormal
        self.absent = []  # Observation-Absent
        self.uncertain = []  # Observation-Uncertain

    def set_sent_repr(self, linked_graphs):
        """Fill the tuple lists and the four category lists from ``linked_graphs``.

        Each graph contributes its non-removable entities/relations to
        ``ent_tuples``/``rel_tuples`` (method 1) and one lower-cased string to
        exactly one of ``uncertain``, ``absent``, ``normal`` or ``abnormal``
        (method 2, checked in that order). A graph whose entities are all
        removed or pruned contributes no string.
        """
        for linked_graph in linked_graphs:
            # Method 1 (7_construct_graph.ipynb): take entities and relations straight from
            # the linked graph with only minimal processing. Judging from the results of
            # 5_1_fsdp_peft_full_graph_text.py, the graph is presumably too complex and
            # noisy for the model to learn from effectively.
            for ent in linked_graph.ents:
                if ent.attr_action == "Removable":
                    continue
                self.ent_tuples.append((ent.tok_str, ent.label, ent.attr_normal, ent.attr_action, ent.attr_change))

            for rel in linked_graph.rels:
                if rel.subj_ent.attr_action == "Removable" or rel.obj_ent.attr_action == "Removable":
                    continue
                self.rel_tuples.append((rel.subj_ent.tok_str, rel.label, rel.obj_ent.tok_str))

            # Method 2: simplify the graph so that it stays learnable.

            # Removable entities cut the chain: every entity that points, directly or
            # indirectly, at a Removable entity is dropped together with it.
            removable_ents = set()
            for ent in linked_graph.ents:
                if ent.attr_action == "Removable":
                    removable_ents.add(ent)
                    removable_ents.update(collect_from_path_ents_via_chain_info(ent))

            # Each graph falls into one of four categories: normal, abnormal, absent, uncertain.
            is_normal = False
            is_absent = False
            is_uncertain = False
            repr_nodes = []
            for ent in linked_graph.ents:
                if ent in removable_ents:
                    continue
                if ent.label == "Observation-Absent":
                    is_absent = True
                if ent.label == "Observation-Present" and ent.attr_normal == "Normal":
                    is_normal = True
                if ent.label == "Observation-Uncertain" or ent.chain_info["suggestive_of"]["to"]:
                    # Uncertain if the label says so, or if any entity has an outgoing
                    # suggestive_of link (one side suffices, relations are stored on both ends).
                    is_uncertain = True
                repr_nodes.append(ent)

            # Additional rules (see reorder_entities()): entities connected via modify are
            # grouped, and the groups are ordered by their mutual dependencies.
            reorder_nodes = reorder_entities(repr_nodes)
            # A str node is a special relation marker, e.g. <suggestive_of>.
            repr_str_list = [node if isinstance(node, str) else node.tok_str.lower() for node in reorder_nodes]

            if not repr_str_list:
                # Nothing survived filtering/pruning: do not record an empty string,
                # which would otherwise be filed under "abnormal".
                continue

            if is_uncertain:
                self.uncertain.append(" ".join(repr_str_list))
            elif is_absent:
                self.absent.append(" ".join(repr_str_list))
            elif is_normal:
                self.normal.append(" ".join(repr_str_list))
            else:
                self.abnormal.append(" ".join(repr_str_list))

    def __repr__(self) -> str:
        return f"{self.sent_text}"

    def __str__(self) -> str:
        return self.__repr__()


class ReprTuple:
    """Order-insensitive, hashable wrapper around a collection of tuples.

    Inner lists are sorted and frozen into tuples, then the outer collection
    is sorted, so two instances built from the same items in any order
    compare (and hash) equal.
    """

    def __init__(self, repr_tuples):
        normalised = []
        for inner in repr_tuples:
            if isinstance(inner, list):
                inner = tuple(sorted(inner))
            normalised.append(inner)
        normalised.sort()
        self.repr_tuples = tuple(normalised)

    def __hash__(self):
        # The canonical nested tuple is hashable as is.
        return hash(self.repr_tuples)

    def __eq__(self, other):
        # Only another ReprTuple with the same canonical content is equal.
        return isinstance(other, ReprTuple) and self.repr_tuples == other.repr_tuples

    def __repr__(self):
        return f"ReprTuple({self.repr_tuples})"


def collect_from_path_ents_via_chain_info(entity, visited=None):
    """Collect every entity that reaches ``entity`` through ``"from"`` links.

    Walks the ``"from"`` lists of all relation types recursively. ``visited``
    (created when omitted, otherwise updated in place) guards against
    revisiting entities. Returns the set of collected entities, which never
    contains ``entity`` itself or anything already in ``visited``.
    """
    visited = set() if visited is None else visited
    collected = set()

    # Already handled: contributes nothing new.
    if entity in visited:
        return collected
    visited.add(entity)

    for from_to_dict in entity.chain_info.values():
        for source in from_to_dict["from"]:
            if source in visited:
                continue
            collected.add(source)
            collected |= collect_from_path_ents_via_chain_info(source, visited)

    return collected

In [5]:
# This part of the code was generated iteratively with ChatGPT.

from collections import defaultdict, deque

# Marker strings that reorder_entities() inserts between two consecutive entity
# groups connected by a suggestive_of / located_at relation.
suggestive_of_label = "<suggestive_of>"
localed_at_label = "<located_at>"


def reorder_entities(ent_list):
    """Group the entities and return them in a dependency-aware order.

    Steps:
      1. Build connected components ("groups") over ``modify`` and ``part_of``
         links (both directions, restricted to ``ent_list``). Each group is
         ordered by ``sort_group_by_part_of()`` and then cleaned up by
         ``prune_redundant_head_entities()``.
      2. Build a dependency graph between groups:
           A part_of B        => B's group comes before A's group
           A located_at B     => A's group comes before B's group
           A suggestive_of B  => A's group comes before B's group
      3. Topologically sort the groups. If there is a cycle, try ignoring one
         "backward" edge (from a later to an earlier group) at a time; if no
         single removal helps, fall back to the original group order.

    Returns:
        A flat list of entities. Between two consecutive groups that are
        connected by an edge, the marker strings ``suggestive_of_label`` and/or
        ``localed_at_label`` are inserted (in that order when both relation
        types exist between the two groups).
    """
    # Step 1: connected components over modify + part_of (the core groups)
    visited = set()
    components = []

    def dfs(ent, group):
        visited.add(ent)
        group.append(ent)
        neighbors = ent.chain_info["modify"]["to"] + ent.chain_info["modify"]["from"] + ent.chain_info["part_of"]["to"] + ent.chain_info["part_of"]["from"]
        for neighbor in neighbors:
            if neighbor in ent_list and neighbor not in visited:
                dfs(neighbor, group)

    for ent in ent_list:
        if ent not in visited:
            group = []
            dfs(ent, group)
            group = sort_group_by_part_of(group)
            group = prune_redundant_head_entities(group)
            components.append(group)

    # Step 2: group-level dependency graph
    ent_to_group = {ent: i for i, group in enumerate(components) for ent in group}
    graph = defaultdict(set)
    edge_types = defaultdict(set)  # (from_group, to_group): {"suggestive_of", "located_at", ...}

    def add_edge(from_ent, to_ent, rel_type):
        # Entities that were filtered out or pruned are not in ent_to_group and are ignored.
        if from_ent in ent_to_group and to_ent in ent_to_group:
            u = ent_to_group[from_ent]
            v = ent_to_group[to_ent]
            if u != v:
                graph[u].add(v)
                # Record every relation type seen between the two groups, not only the
                # first one, so that both markers can be emitted when both exist.
                edge_types[(u, v)].add(rel_type)

    for group in components:
        for ent in group:
            for to_ent in ent.chain_info["part_of"]["to"]:
                # A part_of B: B -> A
                add_edge(to_ent, ent, "part_of")
            for to_ent in ent.chain_info["located_at"]["to"]:
                # A located_at B: A -> B
                add_edge(ent, to_ent, "located_at")
            for to_ent in ent.chain_info["suggestive_of"]["to"]:
                # A suggestive_of B: A -> B
                add_edge(ent, to_ent, "suggestive_of")

    # Step 3: topological sort of the groups (Kahn's algorithm), with cycle breaking
    def attempt_topo_sort(edges_to_ignore=None):
        in_deg_copy = defaultdict(int)
        graph_copy = defaultdict(set)
        for u in graph:
            for v in graph[u]:
                if edges_to_ignore and (u, v) in edges_to_ignore:
                    continue
                graph_copy[u].add(v)
                in_deg_copy[v] += 1

        queue = deque(i for i in range(len(components)) if in_deg_copy[i] == 0)
        sorted_ids = []
        while queue:
            i = queue.popleft()
            sorted_ids.append(i)
            for j in graph_copy[i]:
                in_deg_copy[j] -= 1
                if in_deg_copy[j] == 0:
                    queue.append(j)
        # Shorter than len(components) when a cycle prevented some groups from being emitted.
        return sorted_ids

    sorted_group_ids = attempt_topo_sort()
    if len(sorted_group_ids) != len(components):
        all_edges = {(u, v) for u in graph for v in graph[u]}
        # Only edges pointing from a later group back to an earlier one are candidates.
        candidate_edges = [edge for edge in all_edges if edge[0] > edge[1]]

        for edge in candidate_edges:
            trial_ids = attempt_topo_sort(edges_to_ignore={edge})
            if len(trial_ids) == len(components):
                print(f"Removed backward edge to break cycle: group{edge[0]} → group{edge[1]}")
                sorted_group_ids = trial_ids
                break
        else:
            print("Unresolvable inter-group cycle — fallback to original group order")
            sorted_group_ids = list(range(len(components)))

    result = []
    for idx, gid in enumerate(sorted_group_ids):
        result.extend(components[gid])
        if idx < len(sorted_group_ids) - 1:
            next_gid = sorted_group_ids[idx + 1]
            key = (gid, next_gid)
            if key in edge_types:
                if "suggestive_of" in edge_types[key]:
                    result.append(suggestive_of_label)
                if "located_at" in edge_types[key]:
                    result.append(localed_at_label)
    return result


def sort_group_by_part_of(group):
    """Order the entities of one group so that wholes precede their parts.

    1. Build a directed graph from the ``part_of`` links inside the group
       (edge B -> A means "A part_of B").
    2. Run Kahn's topological sort, always emitting the ready entity with the
       smallest ``tok_indices[0]``. If every entity is emitted, return that order.
    3. On a cycle, ignore one ``part_of`` edge at a time (in the order the edges
       occur in ``group``) and retry; the first edge whose removal makes the
       graph acyclic is reported via ``print`` and the resulting order returned.
       If no single removal helps, fall back to sorting by ``tok_indices[0]``.

    The result is deterministic for a given input order: edges and adjacency are
    kept in lists rather than sets of entities, whose iteration order would
    depend on hash randomisation.
    """
    ent_to_index = {ent: i for i, ent in enumerate(group)}

    def start_of(ent):
        return ent.tok_indices[0]

    # (child, parent) pairs for every "child part_of parent" link inside the group.
    part_of_edges = []
    for ent in group:
        for to_ent in ent.chain_info["part_of"]["to"]:
            if to_ent in ent_to_index and (ent, to_ent) not in part_of_edges:
                part_of_edges.append((ent, to_ent))

    def attempt_topo(ignored_edge=None):
        children = defaultdict(list)
        in_deg = {ent: 0 for ent in group}
        for child, parent in part_of_edges:
            if ignored_edge is not None and (child, parent) == ignored_edge:
                continue
            # Edge parent -> child
            children[parent].append(child)
            in_deg[child] += 1

        ready = sorted((ent for ent in group if in_deg[ent] == 0), key=start_of)
        ordered = []
        while ready:
            # Always take the ready entity that starts earliest in the sentence.
            ent = ready.pop(0)
            ordered.append(ent)
            for child in children.get(ent, ()):
                in_deg[child] -= 1
                if in_deg[child] == 0:
                    ready.append(child)
            ready.sort(key=start_of)
        # Shorter than the group when a cycle kept some entities from becoming ready.
        return ordered

    # First attempt with all edges
    sorted_group = attempt_topo()
    if len(sorted_group) == len(group):
        return sorted_group

    # Cycle detected: drop one edge at a time and retry
    for removed_edge in part_of_edges:
        trial = attempt_topo(ignored_edge=removed_edge)
        if len(trial) == len(group):
            print(f"Removed part_of edge to break cycle: {removed_edge[0].tok_str} part_of {removed_edge[1].tok_str}")
            return trial

    # Still cyclic: fall back to plain sentence order
    print("Unresolvable cycle in part_of within group—fallback to tok_indices order")
    return sorted(group, key=start_of)


def prune_redundant_head_entities(group):
    """Drop obviously wrong entities from one group using simple rules.

    The graph is predicted and therefore noisy. An entity is removed from
    ``group`` when either
      * its ``tok_str`` is a spaCy stop word, or
      * its ``tok_str`` is a single word (no space), it has exactly one link
        in total across all relation types and directions, and its
        ``tok_str`` occurs inside the ``tok_str`` of a later entity of the group.

    ``group`` is modified in place and also returned.
    """
    relation_names = ("modify", "part_of", "located_at", "suggestive_of")

    def link_count(ent):
        # Total number of to/from links over all relation types.
        return sum(len(ent.chain_info[name][direction]) for name in relation_names for direction in ("to", "from"))

    doomed = set()
    for position, ent in enumerate(group):
        text = ent.tok_str
        if text in nlp.Defaults.stop_words:
            doomed.add(ent)
        elif " " not in text and link_count(ent) == 1:
            # Redundant if a later entity already contains this word.
            later_ents = group[position + 1 :]
            if any(text in other.tok_str for other in later_ents):
                doomed.add(ent)

    for ent in doomed:
        group.remove(ent)

    return group

In [6]:
def add_graph_repr(doc):
    """Add the per-sentence graph representation to ``doc`` under ``"graph_reprs2"``.

    For every sentence the entities/relations are resolved, split into linked
    graphs (connected components) and turned into a ``SentenceRepresentation``.
    ``doc["graph_reprs2"]`` receives one dict per sentence with the keys
    ``"normal"``, ``"abnormal"``, ``"absent"`` and ``"uncertain"``. The same
    ``doc`` object is returned.
    """
    per_sentence = zip(doc["cxrgraph_ent"], doc["cxrgraph_rel"], doc["cxrgraph_attr"], doc["radlex"])

    sentence_reprs = []
    for sent_idx, (ents_json, rels_json, attrs_json, radlex_json) in enumerate(per_sentence):
        sent_repr = SentenceRepresentation(doc_key=doc["doc_key"], sent_id=sent_idx, sent_text=doc["split_sents"][sent_idx])

        # Turn the JSON records into Entity / Relation objects.
        ent_list, rel_list = resolve_ent_rel(sent_idx, ents_json, rels_json, attrs_json, radlex_json)

        # One LinkedGraph per connected component of entities.
        seen_ents = set()
        graphs = []
        for ent in ent_list:
            if ent in seen_ents:
                continue
            component = []
            search_linked_ents(ent, seen_ents, component)
            graph = LinkedGraph(component)
            graph.get_involved_rels(rel_list)
            graphs.append(graph)

        sent_repr.set_sent_repr(graphs)
        sentence_reprs.append(sent_repr)

    doc["graph_reprs2"] = [
        {
            "normal": sent_repr.normal,
            "abnormal": sent_repr.abnormal,
            "absent": sent_repr.absent,
            "uncertain": sent_repr.uncertain,
        }
        for sent_repr in sentence_reprs
    ]
    return doc

# Create dataset

In [7]:
# Source (interpreted text) and destination (graph representation) dataset roots.
input_dir = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_text/all/"
output_dir = "/home/yuxiang/liao/workspace/arrg_preprocessing/outputs/interpret_graph/4_labels/"

# Process each report section independently: load, add "graph_reprs2", save.
for section_name in ["findings", "impression"]:
    text_path = os.path.join(input_dir, f"interpret_text_{section_name}")
    graph_path = os.path.join(output_dir, f"interpret_graph_{section_name}")

    ds_text = load_from_disk(text_path)
    ds_graph = ds_text.map(add_graph_repr)
    ds_graph.save_to_disk(graph_path)

Saving the dataset (6/6 shards): 100%|██████████| 343738/343738 [00:01<00:00, 224709.48 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 8825/8825 [00:00<00:00, 224193.12 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2692/2692 [00:00<00:00, 133331.76 examples/s]
Saving the dataset (6/6 shards): 100%|██████████| 365565/365565 [00:02<00:00, 137905.89 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 9308/9308 [00:00<00:00, 70308.30 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2965/2965 [00:00<00:00, 179673.65 examples/s]
