Skip to content

Commit

Permalink
Merge pull request #168 from kenneth-lee-ch/main
Browse files Browse the repository at this point in the history
Added rules 8, 9, 10 to FCI
  • Loading branch information
zhi-yi-huang committed Apr 5, 2024
2 parents 36b0829 + a1b0919 commit d98225b
Showing 1 changed file with 144 additions and 2 deletions.
146 changes: 144 additions & 2 deletions causallearn/search/ConstraintBased/FCI.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from causallearn.utils.cit import *
from causallearn.utils.FAS import fas
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge

from itertools import combinations

def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
if node == edge.get_node1():
Expand Down Expand Up @@ -542,6 +542,142 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m
return change_flag



def rule8(graph: Graph, nodes: List[Node]):
nodes = graph.get_nodes()
changeFlag = False
for node_B in nodes:
adj = graph.get_adjacent_nodes(node_B)
if len(adj) < 2:
continue

cg = ChoiceGenerator(len(adj), 2)
combination = cg.next()

while combination is not None:
node_A = adj[combination[0]]
node_C = adj[combination[1]]
combination = cg.next()

if(graph.get_endpoint(node_A, node_B) == Endpoint.ARROW and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
graph.is_adjacent_to(node_A, node_C) and \
graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE) or \
(graph.get_endpoint(node_A, node_B) == Endpoint.CIRCLE and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
graph.is_adjacent_to(node_A, node_C) and \
graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE):
edge1 = graph.get_edge(node_A, node_C)
graph.remove_edge(edge1)
graph.add_edge(Edge(node_A, node_C,Endpoint.TAIL, Endpoint.ARROW))
changeFlag = True

return changeFlag



def is_possible_parent(graph: Graph, potential_parent_node, child_node):
if graph.node_map[potential_parent_node] == graph.node_map[child_node]:
return False
if not graph.is_adjacent_to(potential_parent_node, child_node):
return False

if graph.get_endpoint(child_node, potential_parent_node) == Endpoint.ARROW or \
graph.get_endpoint(potential_parent_node, child_node) == Endpoint.TAIL:
return False
else:
return True


def find_possible_children(graph: Graph, parent_node, en_nodes=None):
if en_nodes is None:
nodes = graph.get_nodes()
en_nodes = [node for node in nodes if graph.node_map[node] != graph.node_map[parent_node]]

potential_child_nodes = set()
for potential_node in en_nodes:
if is_possible_parent(graph, potential_parent_node=parent_node, child_node=potential_node):
potential_child_nodes.add(potential_node)

return potential_child_nodes

def rule9(graph: Graph, nodes: List[Node]):
changeFlag = False
nodes = graph.get_nodes()
for node_C in nodes:
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
for node_A in intoCArrows:
# we want A o--> C
if not graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE:
continue

# look for a possibly directed uncovered path s.t. B and C are not connected (for the given A o--> C
a_node_idx = graph.node_map[node_A]
c_node_idx = graph.node_map[node_C]
a_adj_nodes = graph.get_adjacent_nodes(node_A)
nodes_set = [node for node in a_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= c_node_idx]
possible_children = find_possible_children(graph, node_A, nodes_set)
for node_B in possible_children:
if graph.is_adjacent_to(node_B, node_C):
continue
if existsSemiDirectedPath(node_from=node_B, node_to=node_C, G=graph):
edge1 = graph.get_edge(node_A, node_C)
graph.remove_edge(edge1)
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
changeFlag = True
break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A
return changeFlag


def rule10(graph: Graph):
changeFlag = False
nodes = graph.get_nodes()
for node_C in nodes:
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
if len(intoCArrows) < 2:
continue
# get all A where A o-> C
Anodes = [node_A for node_A in intoCArrows if graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE]
if len(Anodes) == 0:
continue

for node_A in Anodes:
A_adj_nodes = graph.get_adjacent_nodes(node_A)
en_nodes = [i for i in A_adj_nodes if i is not node_C]
A_possible_children = find_possible_children(graph, parent_node=node_A, en_nodes=en_nodes)
if len(A_possible_children) < 2:
continue

gen = ChoiceGenerator(len(intoCArrows), 2)
choice = gen.next()
while choice is not None:
node_B = intoCArrows[choice[0]]
node_D = intoCArrows[choice[1]]

choice = gen.next()
# we want B->C<-D
if graph.get_endpoint(node_C, node_B) != Endpoint.TAIL:
continue

if graph.get_endpoint(node_C, node_D) != Endpoint.TAIL:
continue

for children in combinations(A_possible_children, 2):
child_one, child_two = children
if not existsSemiDirectedPath(node_from=child_one, node_to=node_B, G=graph) or \
not existsSemiDirectedPath(node_from=child_two, node_to=node_D, G=graph):
continue

if not graph.is_adjacent_to(child_one, child_two):
edge1 = graph.get_edge(node_A, node_C)
graph.remove_edge(edge1)
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
changeFlag = True
break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A

return changeFlag


def visibleEdgeHelperVisit(graph: Graph, node_c: Node, node_a: Node, node_b: Node, path: List[Node]) -> bool:
if path.__contains__(node_a):
return False
Expand Down Expand Up @@ -691,7 +827,6 @@ def _contains_all(set_a: Set[Node], set_b: Set[Node]):
break



def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1,
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, show_progress: bool = True,
**kwargs) -> Tuple[Graph, List[Edge]]:
Expand Down Expand Up @@ -787,6 +922,13 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
if verbose:
print("Epoch")

# rule 8
change_flag = rule8(graph,nodes)
# rule 9
change_flag = rule9(graph, nodes)
# rule 10
change_flag = rule10(graph)

graph.set_pag(True)

edges = get_color_edges(graph)
Expand Down

0 comments on commit d98225b

Please sign in to comment.