diff --git a/causallearn/utils/DAG2PAG.py b/causallearn/utils/DAG2PAG.py index 581f5f20..9980b96b 100644 --- a/causallearn/utils/DAG2PAG.py +++ b/causallearn/utils/DAG2PAG.py @@ -96,12 +96,13 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph: data = np.empty(shape=(0, len(observed_nodes))) independence_test_method = CIT(data, method=d_separation, true_dag=true_dag) - + node_map = PAG.get_node_map() + sepset_reindexed = {(node_map[nodes[i]], node_map[nodes[j]]): sepset[(i, j)] for (i, j) in sepset} while change_flag: change_flag = False change_flag = rulesR1R2cycle(PAG, None, change_flag, False) - change_flag = ruleR3(PAG, sepset, None, change_flag, False) - change_flag = ruleR4B(PAG, -1, data, independence_test_method, 0.05, sep_sets=sepset, + change_flag = ruleR3(PAG, sepset_reindexed, None, change_flag, False) + change_flag = ruleR4B(PAG, -1, data, independence_test_method, 0.05, sep_sets=sepset_reindexed, change_flag=change_flag, bk=None, verbose=False) return PAG