Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions causallearn/graph/GeneralGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def adjust_dpath(self, i: int, j: int):
self.dpath = dpath

def reconstitute_dpath(self, edges: List[Edge]):
self.dpath = np.zeros((self.num_vars, self.num_vars), np.dtype(int))
for i in range(self.num_vars):
self.adjust_dpath(i, i)

Expand All @@ -73,7 +74,11 @@ def reconstitute_dpath(self, edges: List[Edge]):
node2 = edge.get_node2()
i = self.node_map[node1]
j = self.node_map[node2]
self.adjust_dpath(i, j)
if self.is_parent_of(node1, node2):
self.adjust_dpath(i, j)
elif self.is_parent_of(node2, node1):
self.adjust_dpath(j, i)


def collect_ancestors(self, node: Node, ancestors: List[Node]):
if node in ancestors:
Expand Down Expand Up @@ -503,13 +508,13 @@ def is_ancestor_of(self, node1: Node, node2: Node) -> bool:
def is_child_of(self, node1: Node, node2: Node) -> bool:
i = self.node_map[node1]
j = self.node_map[node2]
return self.graph[i, j] == 1 or self.graph[i, j] == Endpoint.ARROW_AND_ARROW.value
return self.graph[i, j] == Endpoint.TAIL.value or self.graph[i, j] == Endpoint.ARROW_AND_ARROW.value

# Returns true iff node1 is a parent of node2.
def is_parent_of(self, node1: Node, node2: Node) -> bool:
i = self.node_map[node1]
j = self.node_map[node2]
return self.graph[j, i] == 1 or self.graph[j, i] == Endpoint.ARROW_AND_ARROW.value
return self.graph[j, i] == Endpoint.ARROW.value and self.graph[i, j] == Endpoint.TAIL.value

# Returns true iff node1 is a proper ancestor of node2.
def is_proper_ancestor_of(self, node1: Node, node2: Node) -> bool:
Expand All @@ -521,9 +526,7 @@ def is_proper_descendant_of(self, node1: Node, node2: Node) -> bool:

# Returns true iff node1 is a descendant of node2.
def is_descendant_of(self, node1: Node, node2: Node) -> bool:
i = self.node_map[node1]
j = self.node_map[node2]
return self.dpath[i, j] == 1
return self.is_ancestor_of(node2, node1)

# Returns the edge connecting node1 and node2, provided a unique such edge exists.
def get_edge(self, node1: Node, node2: Node) -> Edge | None:
Expand Down Expand Up @@ -763,6 +766,8 @@ def remove_edge(self, edge: Edge):
end1 = edge.get_numerical_endpoint1()
end2 = edge.get_numerical_endpoint2()

is_fully_directed = self.is_parent_of(node1, node2) or self.is_parent_of(node2, node1)

if out_of == Endpoint.TAIL_AND_ARROW.value and in_to == Endpoint.TAIL_AND_ARROW.value:
if end1 == Endpoint.ARROW.value:
self.graph[j, i] = -1
Expand Down Expand Up @@ -794,7 +799,8 @@ def remove_edge(self, edge: Edge):
self.graph[j, i] = 0
self.graph[i, j] = 0

self.reconstitute_dpath(self.get_graph_edges())
if is_fully_directed:
self.reconstitute_dpath(self.get_graph_edges())

# Removes the edge connecting the given two nodes, provided there is exactly one such edge.
def remove_connecting_edge(self, node1: Node, node2: Node):
Expand Down
Loading