Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions causallearn/graph/Edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,28 +82,28 @@ def set_endpoint1(self, endpoint: Endpoint):
self.endpoint1 = endpoint

if self.numerical_endpoint_1 == 1 and self.numerical_endpoint_2 == 1:
if endpoint is Endpoint.ARROW:
if endpoint == Endpoint.ARROW:
pass
else:
if endpoint is Endpoint.TAIL:
if endpoint == Endpoint.TAIL:
self.numerical_endpoint_1 = -1
self.numerical_endpoint_2 = 1
else:
if endpoint is Endpoint.CIRCLE:
if endpoint == Endpoint.CIRCLE:
self.numerical_endpoint_1 = 2
self.numerical_endpoint_2 = 1
else:
if endpoint is Endpoint.ARROW and self.numerical_endpoint_2 == 1:
if endpoint == Endpoint.ARROW and self.numerical_endpoint_2 == 1:
self.numerical_endpoint_1 = 1
self.numerical_endpoint_2 = 1
else:
if endpoint is Endpoint.ARROW:
if endpoint == Endpoint.ARROW:
self.numerical_endpoint_1 = 1
else:
if endpoint is Endpoint.TAIL:
if endpoint == Endpoint.TAIL:
self.numerical_endpoint_1 = -1
else:
if endpoint is Endpoint.CIRCLE:
if endpoint == Endpoint.CIRCLE:
self.numerical_endpoint_1 = 2

if self.pointing_left(self.endpoint1, self.endpoint2):
Expand All @@ -123,28 +123,28 @@ def set_endpoint2(self, endpoint: Endpoint):
self.endpoint2 = endpoint

if self.numerical_endpoint_1 == 1 and self.numerical_endpoint_2 == 1:
if endpoint is Endpoint.ARROW:
if endpoint == Endpoint.ARROW:
pass
else:
if endpoint is Endpoint.TAIL:
if endpoint == Endpoint.TAIL:
self.numerical_endpoint_1 = 1
self.numerical_endpoint_2 = -1
else:
if endpoint is Endpoint.CIRCLE:
if endpoint == Endpoint.CIRCLE:
self.numerical_endpoint_1 = 1
self.numerical_endpoint_2 = 2
else:
if endpoint is Endpoint.ARROW and self.numerical_endpoint_2 == 1:
if endpoint == Endpoint.ARROW and self.numerical_endpoint_2 == 1:
self.numerical_endpoint_1 = 1
self.numerical_endpoint_2 = 1
else:
if endpoint is Endpoint.ARROW:
if endpoint == Endpoint.ARROW:
self.numerical_endpoint_2 = 1
else:
if endpoint is Endpoint.TAIL:
if endpoint == Endpoint.TAIL:
self.numerical_endpoint_2 = -1
else:
if endpoint is Endpoint.CIRCLE:
if endpoint == Endpoint.CIRCLE:
self.numerical_endpoint_2 = 2

if self.pointing_left(self.endpoint1, self.endpoint2):
Expand Down Expand Up @@ -216,20 +216,20 @@ def __str__(self):

edge_string = node1.get_name() + " "

if endpoint1 is Endpoint.TAIL:
if endpoint1 == Endpoint.TAIL:
edge_string = edge_string + "-"
else:
if endpoint1 is Endpoint.ARROW:
if endpoint1 == Endpoint.ARROW:
edge_string = edge_string + "<"
else:
edge_string = edge_string + "o"

edge_string = edge_string + "-"

if endpoint2 is Endpoint.TAIL:
if endpoint2 == Endpoint.TAIL:
edge_string = edge_string + "-"
else:
if endpoint2 is Endpoint.ARROW:
if endpoint2 == Endpoint.ARROW:
edge_string = edge_string + ">"
else:
edge_string = edge_string + "o"
Expand Down
20 changes: 10 additions & 10 deletions causallearn/graph/Edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,29 @@ def undirected_edge(self, node_a: Node, node_b: Node) -> Edge:

# return true iff an edge is a bidrected edge <->
def is_bidirected_edge(self, edge: Edge) -> bool:
return edge.get_endpoint1() is Endpoint.ARROW and edge.get_endpoint2() is Endpoint.ARROW
return edge.get_endpoint1() == Endpoint.ARROW and edge.get_endpoint2() == Endpoint.ARROW

# return true iff the given edge is a directed edge -->
def is_directed_edge(self, edge: Edge) -> bool:
if edge.get_endpoint1() is Endpoint.TAIL:
return edge.get_endpoint2() is Endpoint.ARROW
elif edge.get_endpoint2() is Endpoint.TAIL:
return edge.get_endpoint1() is Endpoint.ARROW
if edge.get_endpoint1() == Endpoint.TAIL:
return edge.get_endpoint2() == Endpoint.ARROW
elif edge.get_endpoint2() == Endpoint.TAIL:
return edge.get_endpoint1() == Endpoint.ARROW
else:
return False

# return true iff the given edge is a partially oriented edge o->
def is_partially_oriented_edge(self, edge: Edge) -> bool:
if edge.get_endpoint1() is Endpoint.CIRCLE:
return edge.get_endpoint2() is Endpoint.ARROW
elif edge.get_endpoint2() is Endpoint.CIRCLE:
return edge.get_endpoint1() is Endpoint.ARROW
if edge.get_endpoint1() == Endpoint.CIRCLE:
return edge.get_endpoint2() == Endpoint.ARROW
elif edge.get_endpoint2() == Endpoint.CIRCLE:
return edge.get_endpoint1() == Endpoint.ARROW
else:
return False

# return true iff some edge is an undirected edge --
def is_undirected_edge(self, edge: Edge) -> bool:
return edge.get_endpoint1() is Endpoint.TAIL and edge.get_endpoint2() is Endpoint.TAIL
return edge.get_endpoint1() == Endpoint.TAIL and edge.get_endpoint2() == Endpoint.TAIL

def traverse_directed(self, node: Node, edge: Edge) -> Node | None:
if node == edge.get_node1():
Expand Down
3 changes: 3 additions & 0 deletions causallearn/graph/Endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ class Endpoint(Enum):
# Prints out the name of the type
def __str__(self):
return self.name

def __eq__(self, other):
return self.value == other.value
8 changes: 4 additions & 4 deletions causallearn/utils/GraphUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ def edge_string(self, edge: Edge) -> str:

edge_string = node1.get_name() + " "

if endpoint1 is Endpoint.TAIL:
if endpoint1 == Endpoint.TAIL:
edge_string = edge_string + "-"
else:
if endpoint1 is Endpoint.ARROW:
if endpoint1 == Endpoint.ARROW:
edge_string = edge_string + "<"
else:
edge_string = edge_string + "o"

edge_string = edge_string + "-"

if endpoint2 is Endpoint.TAIL:
if endpoint2 == Endpoint.TAIL:
edge_string = edge_string + "-"
else:
if endpoint2 is Endpoint.ARROW:
if endpoint2 == Endpoint.ARROW:
edge_string = edge_string + ">"
else:
edge_string = edge_string + "o"
Expand Down