From 98ff97929cec59ea120c7b082acfe971d1e7e4ec Mon Sep 17 00:00:00 2001
From: kenneth-lee-ch
Date: Thu, 15 Feb 2024 03:17:35 -0500
Subject: [PATCH 1/2] added rules 8,9,10 for FCI

Signed-off-by: kenneth-lee-ch
---
Review notes (text between "---" and the diffstat is ignored by git am):

What this patch does: imports itertools.combinations and adds rule8,
is_possible_parent, find_possible_children, rule9 and rule10 to FCI.py, then
calls rule8/rule9/rule10 from fci() immediately before graph.set_pag(True).
It also adds 23 blank lines in front of fci(); patch 2/2 deletes 24 blank
lines at that spot.

Points to confirm before merging:
 - rule8 and rule9 accept a `nodes` argument but rebind it to
   graph.get_nodes(), so the value passed in by fci() is never used.
 - fci() assigns each rule's return value to change_flag in turn, so the
   results of rule8 and rule9 are overwritten by rule10's result.
 - rule8 only tests node_A = adj[combination[0]] and
   node_C = adj[combination[1]].  NOTE(review): if ChoiceGenerator yields
   each unordered pair once, the mirrored role assignment is never tried --
   confirm.
 - NOTE(review): in rule10 the `break` appears to leave only the
   `for children` loop while the `while choice` loop keeps running, which
   does not match the "find the next A" comment -- confirm.
 - The rule9 comment starting "# look for a possibly directed uncovered
   path" is missing its closing parenthesis.
 - NOTE(review): this copy was rebuilt from a whitespace-collapsed version of
   the patch.  The number of "+" lines per hunk matches the hunk headers, but
   indentation, the position of blank context lines, the line breaks inside
   the fci() signature and the indent of the rule 8-10 calls in fci() are
   reconstructed -- verify against the base file before applying.

 causallearn/search/ConstraintBased/FCI.py | 168 +++++++++++++++++++++-
 1 file changed, 167 insertions(+), 1 deletion(-)

diff --git a/causallearn/search/ConstraintBased/FCI.py b/causallearn/search/ConstraintBased/FCI.py
index 57d2f17..1d09fbf 100644
--- a/causallearn/search/ConstraintBased/FCI.py
+++ b/causallearn/search/ConstraintBased/FCI.py
@@ -15,7 +15,7 @@ from causallearn.utils.cit import *
 from causallearn.utils.FAS import fas
 from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
 
-
+from itertools import combinations
 
 def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
     if node == edge.get_node1():
@@ -542,6 +542,142 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m
     return change_flag
 
 
+
+def rule8(graph: Graph, nodes: List[Node]):
+    nodes = graph.get_nodes()
+    changeFlag = False
+    for node_B in nodes:
+        adj = graph.get_adjacent_nodes(node_B)
+        if len(adj) < 2:
+            continue
+
+        cg = ChoiceGenerator(len(adj), 2)
+        combination = cg.next()
+
+        while combination is not None:
+            node_A = adj[combination[0]]
+            node_C = adj[combination[1]]
+            combination = cg.next()
+
+            if(graph.get_endpoint(node_A, node_B) == Endpoint.ARROW and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
+                graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
+                    graph.is_adjacent_to(node_A, node_C) and \
+                        graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE) or \
+                (graph.get_endpoint(node_A, node_B) == Endpoint.CIRCLE and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
+                    graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
+                        graph.is_adjacent_to(node_A, node_C) and \
+                            graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE):
+                edge1 = graph.get_edge(node_A, node_C)
+                graph.remove_edge(edge1)
+                graph.add_edge(Edge(node_A, node_C,Endpoint.TAIL, Endpoint.ARROW))
+                changeFlag = True
+
+    return changeFlag
+
+
+
+def is_possible_parent(graph: Graph, potential_parent_node, child_node):
+    if graph.node_map[potential_parent_node] == graph.node_map[child_node]:
+        return False
+    if not graph.is_adjacent_to(potential_parent_node, child_node):
+        return False
+
+    if graph.get_endpoint(child_node, potential_parent_node) == Endpoint.ARROW or \
+        graph.get_endpoint(potential_parent_node, child_node) == Endpoint.TAIL:
+        return False
+    else:
+        return True
+
+
+def find_possible_children(graph: Graph, parent_node, en_nodes=None):
+    if en_nodes is None:
+        nodes = graph.get_nodes()
+        en_nodes = [node for node in nodes if graph.node_map[node] != graph.node_map[parent_node]]
+
+    potential_child_nodes = set()
+    for potential_node in en_nodes:
+        if is_possible_parent(graph, potential_parent_node=parent_node, child_node=potential_node):
+            potential_child_nodes.add(potential_node)
+
+    return potential_child_nodes
+
+def rule9(graph: Graph, nodes: List[Node]):
+    changeFlag = False
+    nodes = graph.get_nodes()
+    for node_C in nodes:
+        intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
+        for node_A in intoCArrows:
+            # we want A o--> C
+            if not graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE:
+                continue
+
+            # look for a possibly directed uncovered path s.t. B and C are not connected (for the given A o--> C
+            a_node_idx = graph.node_map[node_A]
+            c_node_idx = graph.node_map[node_C]
+            a_adj_nodes = graph.get_adjacent_nodes(node_A)
+            nodes_set = [node for node in a_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= c_node_idx]
+            possible_children = find_possible_children(graph, node_A, nodes_set)
+            for node_B in possible_children:
+                if graph.is_adjacent_to(node_B, node_C):
+                    continue
+                if existsSemiDirectedPath(node_from=node_B, node_to=node_C, G=graph):
+                    edge1 = graph.get_edge(node_A, node_C)
+                    graph.remove_edge(edge1)
+                    graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
+                    changeFlag = True
+                    break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A
+    return changeFlag
+
+
+def rule10(graph: Graph):
+    changeFlag = False
+    nodes = graph.get_nodes()
+    for node_C in nodes:
+        intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
+        if len(intoCArrows) < 2:
+            continue
+        # get all A where A o-> C
+        Anodes = [node_A for node_A in intoCArrows if graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE]
+        if len(Anodes) == 0:
+            continue
+
+        for node_A in Anodes:
+            A_adj_nodes = graph.get_adjacent_nodes(node_A)
+            en_nodes = [i for i in A_adj_nodes if i is not node_C]
+            A_possible_children = find_possible_children(graph, parent_node=node_A, en_nodes=en_nodes)
+            if len(A_possible_children) < 2:
+                continue
+
+            gen = ChoiceGenerator(len(intoCArrows), 2)
+            choice = gen.next()
+            while choice is not None:
+                node_B = intoCArrows[choice[0]]
+                node_D = intoCArrows[choice[1]]
+
+                choice = gen.next()
+                # we want B->C<-D
+                if graph.get_endpoint(node_C, node_B) != Endpoint.TAIL:
+                    continue
+
+                if graph.get_endpoint(node_C, node_D) != Endpoint.TAIL:
+                    continue
+
+                for children in combinations(A_possible_children, 2):
+                    child_one, child_two = children
+                    if not existsSemiDirectedPath(node_from=child_one, node_to=node_B, G=graph) or \
+                        not existsSemiDirectedPath(node_from=child_two, node_to=node_D, G=graph):
+                        continue
+
+                    if not graph.is_adjacent_to(child_one, child_two):
+                        edge1 = graph.get_edge(node_A, node_C)
+                        graph.remove_edge(edge1)
+                        graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
+                        changeFlag = True
+                        break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A
+
+    return changeFlag
+
+
 def visibleEdgeHelperVisit(graph: Graph, node_c: Node, node_a: Node, node_b: Node, path: List[Node]) -> bool:
     if path.__contains__(node_a):
         return False
@@ -692,6 +828,29 @@ def _contains_all(set_a: Set[Node], set_b: Set[Node]):
 
 
 
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
 def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1,
         max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, show_progress: bool = True,
         **kwargs) -> Tuple[Graph, List[Edge]]:
@@ -787,6 +946,13 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 
             if verbose:
                 print("Epoch")
 
+        # rule 8
+        change_flag = rule8(graph,nodes)
+        # rule 9
+        change_flag = rule9(graph, nodes)
+        # rule 10
+        change_flag = rule10(graph)
+
     graph.set_pag(True)
 
     edges = get_color_edges(graph)

From a1b09191844faaa9195df5058c021928550321ad Mon Sep 17 00:00:00 2001
From: kenneth-lee-ch
Date: Thu, 15 Feb 2024 03:19:05 -0500
Subject: [PATCH 2/2] clean up whitespace

Signed-off-by: kenneth-lee-ch
---
Review notes (ignored by git am): this patch only deletes 24 blank lines in
front of fci(); no code line is touched.  NOTE(review): the indentation of
the `break` context line is reconstructed -- verify against the base file.

 causallearn/search/ConstraintBased/FCI.py | 24 -----------------------
 1 file changed, 24 deletions(-)

diff --git a/causallearn/search/ConstraintBased/FCI.py b/causallearn/search/ConstraintBased/FCI.py
index 1d09fbf..44f46f9 100644
--- a/causallearn/search/ConstraintBased/FCI.py
+++ b/causallearn/search/ConstraintBased/FCI.py
@@ -827,30 +827,6 @@ def _contains_all(set_a: Set[Node], set_b: Set[Node]):
             break
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1,
         max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, show_progress: bool = True,
         **kwargs) -> Tuple[Graph, List[Edge]]: