diff --git a/causallearn/graph/Edge.py b/causallearn/graph/Edge.py index 73a0e89c..a277b541 100644 --- a/causallearn/graph/Edge.py +++ b/causallearn/graph/Edge.py @@ -82,28 +82,28 @@ def set_endpoint1(self, endpoint: Endpoint): self.endpoint1 = endpoint if self.numerical_endpoint_1 == 1 and self.numerical_endpoint_2 == 1: - if endpoint is Endpoint.ARROW: + if endpoint == Endpoint.ARROW: pass else: - if endpoint is Endpoint.TAIL: + if endpoint == Endpoint.TAIL: self.numerical_endpoint_1 = -1 self.numerical_endpoint_2 = 1 else: - if endpoint is Endpoint.CIRCLE: + if endpoint == Endpoint.CIRCLE: self.numerical_endpoint_1 = 2 self.numerical_endpoint_2 = 1 else: - if endpoint is Endpoint.ARROW and self.numerical_endpoint_2 == 1: + if endpoint == Endpoint.ARROW and self.numerical_endpoint_2 == 1: self.numerical_endpoint_1 = 1 self.numerical_endpoint_2 = 1 else: - if endpoint is Endpoint.ARROW: + if endpoint == Endpoint.ARROW: self.numerical_endpoint_1 = 1 else: - if endpoint is Endpoint.TAIL: + if endpoint == Endpoint.TAIL: self.numerical_endpoint_1 = -1 else: - if endpoint is Endpoint.CIRCLE: + if endpoint == Endpoint.CIRCLE: self.numerical_endpoint_1 = 2 if self.pointing_left(self.endpoint1, self.endpoint2): @@ -123,28 +123,28 @@ def set_endpoint2(self, endpoint: Endpoint): self.endpoint2 = endpoint if self.numerical_endpoint_1 == 1 and self.numerical_endpoint_2 == 1: - if endpoint is Endpoint.ARROW: + if endpoint == Endpoint.ARROW: pass else: - if endpoint is Endpoint.TAIL: + if endpoint == Endpoint.TAIL: self.numerical_endpoint_1 = 1 self.numerical_endpoint_2 = -1 else: - if endpoint is Endpoint.CIRCLE: + if endpoint == Endpoint.CIRCLE: self.numerical_endpoint_1 = 1 self.numerical_endpoint_2 = 2 else: - if endpoint is Endpoint.ARROW and self.numerical_endpoint_2 == 1: + if endpoint == Endpoint.ARROW and self.numerical_endpoint_2 == 1: self.numerical_endpoint_1 = 1 self.numerical_endpoint_2 = 1 else: - if endpoint is Endpoint.ARROW: + if endpoint == Endpoint.ARROW: self.numerical_endpoint_2 = 1 else: - if endpoint is Endpoint.TAIL: + if endpoint == Endpoint.TAIL: self.numerical_endpoint_2 = -1 else: - if endpoint is Endpoint.CIRCLE: + if endpoint == Endpoint.CIRCLE: self.numerical_endpoint_2 = 2 if self.pointing_left(self.endpoint1, self.endpoint2): @@ -216,20 +216,20 @@ def __str__(self): edge_string = node1.get_name() + " " - if endpoint1 is Endpoint.TAIL: + if endpoint1 == Endpoint.TAIL: edge_string = edge_string + "-" else: - if endpoint1 is Endpoint.ARROW: + if endpoint1 == Endpoint.ARROW: edge_string = edge_string + "<" else: edge_string = edge_string + "o" edge_string = edge_string + "-" - if endpoint2 is Endpoint.TAIL: + if endpoint2 == Endpoint.TAIL: edge_string = edge_string + "-" else: - if endpoint2 is Endpoint.ARROW: + if endpoint2 == Endpoint.ARROW: edge_string = edge_string + ">" else: edge_string = edge_string + "o" diff --git a/causallearn/graph/Edges.py b/causallearn/graph/Edges.py index 3a47713b..ff8ccbb4 100644 --- a/causallearn/graph/Edges.py +++ b/causallearn/graph/Edges.py @@ -30,29 +30,29 @@ def undirected_edge(self, node_a: Node, node_b: Node) -> Edge: # return true iff an edge is a bidrected edge <-> def is_bidirected_edge(self, edge: Edge) -> bool: - return edge.get_endpoint1() is Endpoint.ARROW and edge.get_endpoint2() is Endpoint.ARROW + return edge.get_endpoint1() == Endpoint.ARROW and edge.get_endpoint2() == Endpoint.ARROW # return true iff the given edge is a directed edge --> def is_directed_edge(self, edge: Edge) -> bool: - if edge.get_endpoint1() is Endpoint.TAIL: - return edge.get_endpoint2() is Endpoint.ARROW - elif edge.get_endpoint2() is Endpoint.TAIL: - return edge.get_endpoint1() is Endpoint.ARROW + if edge.get_endpoint1() == Endpoint.TAIL: + return edge.get_endpoint2() == Endpoint.ARROW + elif edge.get_endpoint2() == Endpoint.TAIL: + return edge.get_endpoint1() == Endpoint.ARROW else: return False # return true iff the given edge is a partially oriented edge o-> def is_partially_oriented_edge(self, edge: Edge) -> bool: - if edge.get_endpoint1() is Endpoint.CIRCLE: - return edge.get_endpoint2() is Endpoint.ARROW - elif edge.get_endpoint2() is Endpoint.CIRCLE: - return edge.get_endpoint1() is Endpoint.ARROW + if edge.get_endpoint1() == Endpoint.CIRCLE: + return edge.get_endpoint2() == Endpoint.ARROW + elif edge.get_endpoint2() == Endpoint.CIRCLE: + return edge.get_endpoint1() == Endpoint.ARROW else: return False # return true iff some edge is an undirected edge -- def is_undirected_edge(self, edge: Edge) -> bool: - return edge.get_endpoint1() is Endpoint.TAIL and edge.get_endpoint2() is Endpoint.TAIL + return edge.get_endpoint1() == Endpoint.TAIL and edge.get_endpoint2() == Endpoint.TAIL def traverse_directed(self, node: Node, edge: Edge) -> Node | None: if node == edge.get_node1(): diff --git a/causallearn/graph/Endpoint.py b/causallearn/graph/Endpoint.py index 31efe5d1..df551955 100644 --- a/causallearn/graph/Endpoint.py +++ b/causallearn/graph/Endpoint.py @@ -18,3 +18,6 @@ class Endpoint(Enum): # Prints out the name of the type def __str__(self): return self.name + + def __eq__(self, other): + return self.value == other.value diff --git a/causallearn/utils/GraphUtils.py b/causallearn/utils/GraphUtils.py index c63ada52..9a9497d6 100644 --- a/causallearn/utils/GraphUtils.py +++ b/causallearn/utils/GraphUtils.py @@ -62,20 +62,20 @@ def edge_string(self, edge: Edge) -> str: edge_string = node1.get_name() + " " - if endpoint1 is Endpoint.TAIL: + if endpoint1 == Endpoint.TAIL: edge_string = edge_string + "-" else: - if endpoint1 is Endpoint.ARROW: + if endpoint1 == Endpoint.ARROW: edge_string = edge_string + "<" else: edge_string = edge_string + "o" edge_string = edge_string + "-" - if endpoint2 is Endpoint.TAIL: + if endpoint2 == Endpoint.TAIL: edge_string = edge_string + "-" else: - if endpoint2 is Endpoint.ARROW: + if endpoint2 == Endpoint.ARROW: edge_string = edge_string + ">" else: edge_string = edge_string + "o"