From c3d23e5096c2f59b54ba38af9d69049a00b1102b Mon Sep 17 00:00:00 2001 From: Haoyue Dai Date: Thu, 13 Apr 2023 16:18:39 +0800 Subject: [PATCH 1/3] fixed endpoint comparison bug Signed-off-by: Haoyue Dai --- causallearn/graph/Endpoint.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/causallearn/graph/Endpoint.py b/causallearn/graph/Endpoint.py index 31efe5d1..2ac36695 100644 --- a/causallearn/graph/Endpoint.py +++ b/causallearn/graph/Endpoint.py @@ -18,3 +18,6 @@ class Endpoint(Enum): # Prints out the name of the type def __str__(self): return self.name + + def __eq__(self, other): + return self.name == other.name From 0ae3dfa67d3a2bc3e4f17b867742b339a9208b4d Mon Sep 17 00:00:00 2001 From: Haoyue Dai Date: Thu, 13 Apr 2023 17:14:51 +0800 Subject: [PATCH 2/3] fixed endpoint comparison bug (compare int value) Signed-off-by: Haoyue Dai --- causallearn/graph/Endpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/causallearn/graph/Endpoint.py b/causallearn/graph/Endpoint.py index 2ac36695..df551955 100644 --- a/causallearn/graph/Endpoint.py +++ b/causallearn/graph/Endpoint.py @@ -20,4 +20,4 @@ def __str__(self): return self.name def __eq__(self, other): - return self.name == other.name + return self.value == other.value From b3beba7370ee93dd45ea39ff20f06b35db74efef Mon Sep 17 00:00:00 2001 From: Haoyue Dai Date: Thu, 13 Apr 2023 19:15:17 +0800 Subject: [PATCH 3/3] replace all the "is Endpoint" with "== Endpoint" Signed-off-by: Haoyue Dai --- causallearn/graph/Edge.py | 36 ++++++++++++++++----------------- causallearn/graph/Edges.py | 20 +++++++++--------- causallearn/utils/GraphUtils.py | 8 ++++---- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/causallearn/graph/Edge.py b/causallearn/graph/Edge.py index 73a0e89c..a277b541 100644 --- a/causallearn/graph/Edge.py +++ b/causallearn/graph/Edge.py @@ -82,28 +82,28 @@ def set_endpoint1(self, endpoint: Endpoint): self.endpoint1 = endpoint if self.numerical_endpoint_1 == 1 and self.numerical_endpoint_2 == 1: - if endpoint is Endpoint.ARROW: + if endpoint == Endpoint.ARROW: pass else: - if endpoint is Endpoint.TAIL: + if endpoint == Endpoint.TAIL: self.numerical_endpoint_1 = -1 self.numerical_endpoint_2 = 1 else: - if endpoint is Endpoint.CIRCLE: + if endpoint == Endpoint.CIRCLE: self.numerical_endpoint_1 = 2 self.numerical_endpoint_2 = 1 else: - if endpoint is Endpoint.ARROW and self.numerical_endpoint_2 == 1: + if endpoint == Endpoint.ARROW and self.numerical_endpoint_2 == 1: self.numerical_endpoint_1 = 1 self.numerical_endpoint_2 = 1 else: - if endpoint is Endpoint.ARROW: + if endpoint == Endpoint.ARROW: self.numerical_endpoint_1 = 1 else: - if endpoint is Endpoint.TAIL: + if endpoint == Endpoint.TAIL: self.numerical_endpoint_1 = -1 else: - if endpoint is Endpoint.CIRCLE: + if endpoint == Endpoint.CIRCLE: self.numerical_endpoint_1 = 2 if self.pointing_left(self.endpoint1, self.endpoint2): @@ -123,28 +123,28 @@ def set_endpoint2(self, endpoint: Endpoint): self.endpoint2 = endpoint if self.numerical_endpoint_1 == 1 and self.numerical_endpoint_2 == 1: - if endpoint is Endpoint.ARROW: + if endpoint == Endpoint.ARROW: pass else: - if endpoint is Endpoint.TAIL: + if endpoint == Endpoint.TAIL: self.numerical_endpoint_1 = 1 self.numerical_endpoint_2 = -1 else: - if endpoint is Endpoint.CIRCLE: + if endpoint == Endpoint.CIRCLE: self.numerical_endpoint_1 = 1 self.numerical_endpoint_2 = 2 else: - if endpoint is Endpoint.ARROW and self.numerical_endpoint_2 == 1: + if endpoint == Endpoint.ARROW and self.numerical_endpoint_2 == 1: self.numerical_endpoint_1 = 1 self.numerical_endpoint_2 = 1 else: - if endpoint is Endpoint.ARROW: + if endpoint == Endpoint.ARROW: self.numerical_endpoint_2 = 1 else: - if endpoint is Endpoint.TAIL: + if endpoint == Endpoint.TAIL: self.numerical_endpoint_2 = -1 else: - if endpoint is Endpoint.CIRCLE: + if endpoint == Endpoint.CIRCLE: self.numerical_endpoint_2 = 2 if self.pointing_left(self.endpoint1, self.endpoint2): @@ -216,20 +216,20 @@ def __str__(self): edge_string = node1.get_name() + " " - if endpoint1 is Endpoint.TAIL: + if endpoint1 == Endpoint.TAIL: edge_string = edge_string + "-" else: - if endpoint1 is Endpoint.ARROW: + if endpoint1 == Endpoint.ARROW: edge_string = edge_string + "<" else: edge_string = edge_string + "o" edge_string = edge_string + "-" - if endpoint2 is Endpoint.TAIL: + if endpoint2 == Endpoint.TAIL: edge_string = edge_string + "-" else: - if endpoint2 is Endpoint.ARROW: + if endpoint2 == Endpoint.ARROW: edge_string = edge_string + ">" else: edge_string = edge_string + "o" diff --git a/causallearn/graph/Edges.py b/causallearn/graph/Edges.py index 3a47713b..ff8ccbb4 100644 --- a/causallearn/graph/Edges.py +++ b/causallearn/graph/Edges.py @@ -30,29 +30,29 @@ def undirected_edge(self, node_a: Node, node_b: Node) -> Edge: # return true iff an edge is a bidrected edge <-> def is_bidirected_edge(self, edge: Edge) -> bool: - return edge.get_endpoint1() is Endpoint.ARROW and edge.get_endpoint2() is Endpoint.ARROW + return edge.get_endpoint1() == Endpoint.ARROW and edge.get_endpoint2() == Endpoint.ARROW # return true iff the given edge is a directed edge --> def is_directed_edge(self, edge: Edge) -> bool: - if edge.get_endpoint1() is Endpoint.TAIL: - return edge.get_endpoint2() is Endpoint.ARROW - elif edge.get_endpoint2() is Endpoint.TAIL: - return edge.get_endpoint1() is Endpoint.ARROW + if edge.get_endpoint1() == Endpoint.TAIL: + return edge.get_endpoint2() == Endpoint.ARROW + elif edge.get_endpoint2() == Endpoint.TAIL: + return edge.get_endpoint1() == Endpoint.ARROW else: return False # return true iff the given edge is a partially oriented edge o-> def is_partially_oriented_edge(self, edge: Edge) -> bool: - if edge.get_endpoint1() is Endpoint.CIRCLE: - return edge.get_endpoint2() is Endpoint.ARROW - elif edge.get_endpoint2() is Endpoint.CIRCLE: - return edge.get_endpoint1() is Endpoint.ARROW + if edge.get_endpoint1() == Endpoint.CIRCLE: + return edge.get_endpoint2() == Endpoint.ARROW + elif edge.get_endpoint2() == Endpoint.CIRCLE: + return edge.get_endpoint1() == Endpoint.ARROW else: return False # return true iff some edge is an undirected edge -- def is_undirected_edge(self, edge: Edge) -> bool: - return edge.get_endpoint1() is Endpoint.TAIL and edge.get_endpoint2() is Endpoint.TAIL + return edge.get_endpoint1() == Endpoint.TAIL and edge.get_endpoint2() == Endpoint.TAIL def traverse_directed(self, node: Node, edge: Edge) -> Node | None: if node == edge.get_node1(): diff --git a/causallearn/utils/GraphUtils.py b/causallearn/utils/GraphUtils.py index c63ada52..9a9497d6 100644 --- a/causallearn/utils/GraphUtils.py +++ b/causallearn/utils/GraphUtils.py @@ -62,20 +62,20 @@ def edge_string(self, edge: Edge) -> str: edge_string = node1.get_name() + " " - if endpoint1 is Endpoint.TAIL: + if endpoint1 == Endpoint.TAIL: edge_string = edge_string + "-" else: - if endpoint1 is Endpoint.ARROW: + if endpoint1 == Endpoint.ARROW: edge_string = edge_string + "<" else: edge_string = edge_string + "o" edge_string = edge_string + "-" - if endpoint2 is Endpoint.TAIL: + if endpoint2 == Endpoint.TAIL: edge_string = edge_string + "-" else: - if endpoint2 is Endpoint.ARROW: + if endpoint2 == Endpoint.ARROW: edge_string = edge_string + ">" else: edge_string = edge_string + "o"