Description
Hi,
thanks for the great work on the package.
I think I found a bug in `GeneralGraph.subgraph()` (`causallearn.graph.GeneralGraph`) while building on top of that method.
My code:

```python
from causallearn.graph.GeneralGraph import GeneralGraph
import numpy as np

# cdag and cluster3 come from my own code
_, relevant_nodes = cdag.get_parents_plus(cluster3)  # a list of Node objects
# cdag.cg.G.subgraph(relevant_nodes)
subgraph = GeneralGraph(relevant_nodes)
graph = cdag.cg.G.graph  # ndarray
for i in range(cdag.cg.G.num_vars):
    print(i)
    if cdag.cg.G.nodes[i] not in relevant_nodes:
        print(cdag.cg.G.nodes[i].get_name())
        graph = np.delete(graph, i, axis=0)
```

This throws an error: `index 8 is out of bounds for axis 0 with size 8`
My code is specific to my environment, but logically it does the same as this minimal example:

```python
import numpy as np

array = np.zeros((5, 5))
for i in range(5):
    for j in range(5):
        array[i, j] = i + j

delete = [1, 2, 4]
for i in range(5):
    if i in delete:
        # fails at i = 4: only 3 rows are left by then
        array = np.delete(array, i, axis=0)
```

In causallearn, the graph is an ndarray, and `subgraph()` deletes its rows/columns one at a time inside the loop. Each deletion shrinks the array, so the loop index no longer matches the original positions: later deletions hit the wrong row, and eventually an index is out of bounds.
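To make the index shift visible, here is a self-contained sketch of the same loop. The helper `delete_one_at_a_time` is a hypothetical name of mine, not causallearn code; because `array[i, j] = i + j`, column 0 records each row's original index.

```python
import numpy as np


def delete_one_at_a_time(array, delete):
    """Mimic the loop above: delete row i as soon as it is reached."""
    for i in range(array.shape[0]):  # range is fixed at the ORIGINAL size
        if i in delete:
            array = np.delete(array, i, axis=0)  # later rows shift up by one
    return array


# array[i, j] = i + j, so column 0 holds each row's original index
array = np.add.outer(np.arange(5), np.arange(5)).astype(float)

# Deleting rows 1 and 2 raises no error, but removes original rows 1 and 3
print(delete_one_at_a_time(array, [1, 2])[:, 0])

# Deleting rows 1, 2 and 4 fails: only 3 rows are left when i reaches 4
try:
    delete_one_at_a_time(array, [1, 2, 4])
except IndexError as err:
    print("IndexError:", err)
```

The first call prints `[0. 2. 4.]`, so row 3 is gone instead of row 2. If causallearn's loop has the same structure, it could also silently return a wrong subgraph whenever the error happens not to trigger.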
Interestingly, when I restrict directly from the graph's own node list, I don't get an error:

```python
from causallearn.graph.GraphClass import CausalGraph

test = CausalGraph(no_of_var=5, node_names=['X1', 'X2', 'X3', 'X4', 'X5'])
node_list = test.G.get_nodes()
restricted_nodes = node_list[0:2] + node_list[3:5]  # drops X3
subgraph = test.G.subgraph(restricted_nodes)
```

Am I missing something, or is this a bug?
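My current guess at why the second example works: `restricted_nodes` drops only X3 (index 2), so a single index is deleted and nothing after it gets shifted. The sketch below approximates the one-at-a-time deletion described above on a plain 5 x 5 matrix. `buggy_restrict` is a hypothetical helper of mine, and it deletes the row and column in the same iteration, which may differ in detail from the actual library source.

```python
import numpy as np

NAMES = ["X1", "X2", "X3", "X4", "X5"]


def buggy_restrict(graph, names, keep):
    """Approximate the iterative deletion (my sketch, not the library source)."""
    for i in range(len(names)):
        if names[i] not in keep:
            graph = np.delete(graph, i, axis=0)  # shrinks the matrix...
            graph = np.delete(graph, i, axis=1)  # ...but i keeps counting up
    return graph


graph = np.zeros((5, 5))

# Dropping only X3: one deletion per axis, nothing is shifted afterwards
print(buggy_restrict(graph, NAMES, ["X1", "X2", "X4", "X5"]).shape)

# Dropping X2 and X5: index 4 is requested after the matrix shrank to 4 x 4
try:
    buggy_restrict(graph, NAMES, ["X1", "X3", "X4"])
except IndexError as err:
    print("IndexError:", err)
```

If this reading is right, I would expect `test.G.subgraph()` on the same `CausalGraph` to fail as well when X2 and X5 are dropped.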
A fix (which I have also submitted as pull request #118) would be to change the code to:

```python
def subgraph(self, nodes: List[Node]):
    # np, List and Node are already in scope in GeneralGraph.py
    subgraph = GeneralGraph(nodes)
    graph = self.graph
    # Collect all indices first, so every index refers to the original matrix
    nodes_to_delete = []
    for i in range(self.num_vars):
        if self.nodes[i] not in nodes:
            nodes_to_delete.append(i)
    # Delete all unwanted rows, then all unwanted columns, in one call per axis
    graph = np.delete(graph, nodes_to_delete, axis=0)
    graph = np.delete(graph, nodes_to_delete, axis=1)
    subgraph.graph = graph
    subgraph.reconstitute_dpath(subgraph.get_graph_edges())
    return subgraph
```

Let me know what you think.
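For reference, here is a standalone sketch of the same idea on a plain matrix with node names instead of `Node` objects. `restrict_by_delete` follows the fix above; `restrict_by_selection` is an alternative I'm adding for comparison that uses NumPy's `np.ix_` to pick the kept rows and columns. Both helper names are hypothetical.

```python
import numpy as np


def restrict_by_delete(graph, names, keep):
    """Idea of the fix above: collect all indices first, then delete once per axis."""
    to_delete = [i for i, name in enumerate(names) if name not in keep]
    graph = np.delete(graph, to_delete, axis=0)
    return np.delete(graph, to_delete, axis=1)


def restrict_by_selection(graph, names, keep):
    """Alternative: select the kept rows/columns in the order given by `keep`."""
    idx = [names.index(name) for name in keep]
    return graph[np.ix_(idx, idx)]


names = ["X1", "X2", "X3", "X4", "X5"]
# graph[i, j] = 10 * i + j, so each entry encodes its original (row, column)
graph = np.add.outer(10 * np.arange(5), np.arange(5))

# Dropping X2 and X5 means two deletions, the case one-at-a-time deletion gets wrong
print(restrict_by_delete(graph, names, ["X1", "X3", "X4"]))
print(restrict_by_selection(graph, names, ["X1", "X3", "X4"]))

# The two differ only when `keep` is not in the graph's own node order
print(restrict_by_delete(graph, names, ["X4", "X1"]))     # rows stay in X1, X4 order
print(restrict_by_selection(graph, names, ["X4", "X1"]))  # rows follow X4, X1 order
```

As far as I can tell, `GeneralGraph(nodes)` takes its node order from the `nodes` argument, while `np.delete` keeps the remaining rows in `self.nodes` order. So with the fix as proposed, passing `nodes` in a different order than the original graph could misalign the matrix and the node list; selecting with `np.ix_` in the order of `nodes` would avoid that.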
Best,
Jan Marco