Skip to content

Commit 7dfd65b

Browse files
authored
Add bellman_ford_path (#64)
* Add `bellman_ford_path` * Make sure README.md lists all algorithms
1 parent 3cc4180 commit 7dfd65b

File tree

9 files changed

+194
-14
lines changed

9 files changed

+194
-14
lines changed

Diff for: README.md

+1
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ dispatch pattern shown above.
178178
- Shortest Paths
179179
- all_pairs_bellman_ford_path_length
180180
- all_pairs_shortest_path_length
181+
- bellman_ford_path
181182
- floyd_warshall
182183
- floyd_warshall_numpy
183184
- floyd_warshall_predecessor_and_distance

Diff for: graphblas_algorithms/algorithms/_bfs.py

+42-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""BFS routines used by other algorithms"""
22

33
import numpy as np
4-
from graphblas import Matrix, Vector, binary, replace, unary
4+
from graphblas import Matrix, Vector, binary, indexunary, replace, semiring, unary
55
from graphblas.semiring import any_pair
66

77

@@ -11,8 +11,13 @@ def _get_cutoff(n, cutoff):
1111
return cutoff + 1 # Inclusive
1212

1313

14-
def _plain_bfs(G, source, *, cutoff=None):
15-
index = G._key_to_id[source]
14+
def _bfs_plain(G, source=None, target=None, *, index=None, cutoff=None):
15+
if source is not None:
16+
index = G._key_to_id[source]
17+
if target is not None:
18+
dst_id = G._key_to_id[target]
19+
else:
20+
dst_id = None
1621
A = G.get_property("offdiag")
1722
n = A.nrows
1823
v = Vector(bool, n, name="bfs_plain")
@@ -25,6 +30,8 @@ def _plain_bfs(G, source, *, cutoff=None):
2530
q(~v.S, replace) << any_pair_bool(q @ A)
2631
if q.nvals == 0:
2732
break
33+
if dst_id is not None and dst_id in q:
34+
break
2835
v(q.S) << True
2936
return v
3037

@@ -83,8 +90,38 @@ def _bfs_levels(G, nodes, cutoff=None, *, dtype=int):
8390
return D
8491

8592

93+
def _bfs_parent(G, source, cutoff=None, *, target=None, transpose=False, dtype=int):
    """Run a BFS from ``source`` and record, for each reached node, a parent id.

    Parameters
    ----------
    G
        Graph object; this function reads ``G._key_to_id``, ``G.get_property("offdiag")``
        and ``G.is_directed()``.
    source
        Key of the node where the search starts.
    cutoff
        Passed to ``_get_cutoff`` together with the number of nodes to bound the
        number of BFS steps.
    target
        Optional node key. When given, the search stops as soon as the target
        shows up in the frontier.
    transpose
        If true and ``G`` is directed, the transposed adjacency matrix is traversed.
    dtype
        Dtype of the returned vector. ``bool`` is replaced by ``int`` because the
        stored values are node ids.

    Returns
    -------
    Vector
        Vector named ``"bfs_parent"`` whose entries are the ids of the frontier
        node each reached node was discovered from; the source maps to itself.
    """
    parent_dtype = int if dtype == bool else dtype
    src_id = G._key_to_id[source]
    dst_id = None if target is None else G._key_to_id[target]

    adjacency = G.get_property("offdiag")
    if transpose and G.is_directed():
        adjacency = adjacency.T  # TODO: should we use "AT" instead?
    size = adjacency.nrows

    parents = Vector(parent_dtype, size, name="bfs_parent")
    frontier = Vector(parent_dtype, size, name="q")
    # The source is its own parent; frontier values always hold candidate parent ids.
    parents[src_id] = src_id
    frontier[src_id] = src_id

    pick_parent = semiring.min_first[parents.dtype]
    own_index = indexunary.index[parents.dtype]
    max_steps = _get_cutoff(size, cutoff)
    for _step in range(1, max_steps):
        # Expand only into nodes that do not have a parent yet.  The value carried
        # over is taken from the frontier side (`first`), combined with `min`.
        frontier(~parents.S, replace) << pick_parent(frontier @ adjacency)
        if frontier.nvals == 0:
            break
        parents(frontier.S) << frontier
        if dst_id is not None and dst_id in frontier:
            break
        # Replace each frontier value by the node's own index so that the next
        # expansion propagates this node as the parent.
        frontier << own_index(frontier)
    return parents
121+
122+
86123
# TODO: benchmark this and the version commented out below
87-
def _plain_bfs_bidirectional(G, source):
124+
def _bfs_plain_bidirectional(G, source):
88125
# Bi-directional BFS w/o symmetrizing the adjacency matrix
89126
index = G._key_to_id[source]
90127
A = G.get_property("offdiag")
@@ -125,7 +162,7 @@ def _plain_bfs_bidirectional(G, source):
125162

126163

127164
"""
128-
def _plain_bfs_bidirectional(G, source):
165+
def _bfs_plain_bidirectional(G, source):
129166
# Bi-directional BFS w/o symmetrizing the adjacency matrix
130167
index = G._key_to_id[source]
131168
A = G.get_property("offdiag")
Diff for: graphblas_algorithms/algorithms/components/connected.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from .._bfs import _plain_bfs
1+
from .._bfs import _bfs_plain
22
from ..exceptions import PointlessConcept
33

44

55
def is_connected(G):
66
if len(G) == 0:
77
raise PointlessConcept("Connectivity is undefined for the null graph.")
8-
return _plain_bfs(G, next(iter(G))).nvals == len(G)
8+
return _bfs_plain(G, next(iter(G))).nvals == len(G)
99

1010

1111
def node_connected_component(G, n):
12-
return _plain_bfs(G, n)
12+
return _bfs_plain(G, n)
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from .._bfs import _plain_bfs_bidirectional
1+
from .._bfs import _bfs_plain_bidirectional
22
from ..exceptions import PointlessConcept
33

44

55
def is_weakly_connected(G):
66
if len(G) == 0:
77
raise PointlessConcept("Connectivity is undefined for the null graph.")
8-
return _plain_bfs_bidirectional(G, next(iter(G))).nvals == len(G)
8+
return _bfs_plain_bidirectional(G, next(iter(G))).nvals == len(G)

Diff for: graphblas_algorithms/algorithms/shortest_paths/weighted.py

+114-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import numpy as np
2-
from graphblas import Matrix, Vector, binary, monoid, replace, select, unary
2+
from graphblas import Matrix, Vector, binary, indexunary, monoid, replace, select, unary
33
from graphblas.semiring import any_pair, min_plus
44

5-
from .._bfs import _bfs_level, _bfs_levels
5+
from .._bfs import _bfs_level, _bfs_levels, _bfs_parent, _bfs_plain
66
from ..exceptions import Unbounded
77

88
__all__ = [
99
"single_source_bellman_ford_path_length",
10+
"bellman_ford_path",
1011
"bellman_ford_path_lengths",
1112
"negative_edge_cycle",
1213
]
@@ -164,6 +165,117 @@ def bellman_ford_path_lengths(G, nodes=None, *, expand_output=False):
164165
return D
165166

166167

168+
def _reconstruct_path_from_parents(G, parents, src, dst):
    """Follow parent pointers from ``dst`` back to ``src`` and return the path as keys.

    Parameters
    ----------
    G
        Graph object; only ``G.list_to_keys`` is used, to translate ids to keys.
    parents
        Vector mapping a node id to the id of its parent.
    src, dst
        Ids (not keys) of the first and last node of the path.

    Returns
    -------
    An empty list if ``dst`` has no entry in ``parents``; otherwise the result of
    ``G.list_to_keys`` applied to the ids ordered from ``src`` to ``dst``.
    """
    node_ids, parent_ids = parents.to_coo(sort=False)
    parent_of = dict(zip(node_ids.tolist(), parent_ids.tolist()))
    if dst not in parent_of:
        # The target was never assigned a parent, so there is no path to report.
        return []
    # Build the path backwards (dst first), then hand it over in forward order.
    backwards = [dst]
    while backwards[-1] != src:
        backwards.append(parent_of[backwards[-1]])
    return G.list_to_keys(reversed(backwards))
179+
180+
181+
def bellman_ford_path(G, source, target):
    """Return a shortest weighted path from ``source`` to ``target``.

    Parameters
    ----------
    G
        Graph object providing ``_key_to_id``, ``get_property``/``get_properties``,
        ``list_to_ids`` and ``list_to_keys``.
    source, target
        Keys of the start and end nodes.  They are looked up in ``G._key_to_id``;
        a missing key propagates the lookup error (the nxapi wrapper translates
        ``KeyError`` into ``NodeNotFound``).

    Returns
    -------
    The path as produced by ``_reconstruct_path_from_parents``: node keys ordered
    from ``source`` to ``target``, or an empty list if ``target`` never received
    a parent.

    Raises
    ------
    Unbounded
        If a negative cycle affecting the result is detected.
    """
    src_id = G._key_to_id[source]
    dst_id = G._key_to_id[target]
    if G.get_property("is_iso"):
        # If the edges are iso-valued (and positive), then we can simply do level BFS
        is_negative = G.get_property("has_negative_edges+")
        if not is_negative:
            p = _bfs_parent(G, source, target=target)
            return _reconstruct_path_from_parents(G, p, src_id, dst_id)
        raise Unbounded("Negative cycle detected.")
    A, is_negative, has_negative_diagonal = G.get_properties(
        "offdiag has_negative_edges- has_negative_diagonal"
    )
    if A.dtype == bool:
        # Should we upcast e.g. INT8 to INT64 as well?
        dtype = int
    else:
        dtype = A.dtype
    cutoff = None
    n = A.nrows
    # `d` holds the best known distance per node, `p` the parent id per node.
    d = Vector(dtype, n, name="bellman_ford_path_length")
    d[src_id] = 0
    p = Vector(int, n, name="bellman_ford_path_parent")
    p[src_id] = src_id

    prev = d.dup(name="prev")
    cur = Vector(dtype, n, name="cur")
    indices = Vector(int, n, name="indices")
    mask = Vector(bool, n, name="mask")
    B = Matrix(dtype, n, n, name="B")
    Indices = Matrix(int, n, n, name="Indices")
    cols = prev.to_coo(values=False)[0]
    one = unary.one[bool]
    for _i in range(n - 1):
        # This is a slightly modified Bellman-Ford algorithm.
        # `cur` is the current frontier of values that improved in the previous iteration.
        # This means that in this iteration we drop values from `cur` that are not better.
        cur << min_plus(prev @ A)
        if cutoff is not None:
            cur << select.valuele(cur, cutoff)

        # Mask is True where cur not in d or cur < d
        mask << one(cur)
        mask(binary.second) << binary.lt(cur & d)

        # Drop values from `cur` that didn't improve
        cur(mask.V, replace) << cur
        if cur.nvals == 0:
            break
        # Update `d` with values that improved
        d(cur.S) << cur
        if not is_negative:
            # Limit exploration if we have a target
            cutoff = cur.get(dst_id, cutoff)

        # Now try to find the parents!
        # This is also not standard. Typically, UDTs and UDFs are used to keep
        # track of both the minimum element and the parent id at the same time.
        # Only include rows and columns that were used this iteration.
        rows = cols
        cols = cur.to_coo(values=False)[0]
        B.clear()
        B[rows, cols] = A[rows, cols]

        # Reverse engineer to determine parent
        B << binary.plus(prev & B)
        B << binary.iseq(B & cur)
        B << select.valuene(B, False)
        Indices << indexunary.rowindex(B)
        indices << Indices.reduce_columnwise(monoid.min)
        p(indices.S) << indices
        prev, cur = cur, prev
    else:
        # Check for negative cycle when for loop completes without breaking
        cur << min_plus(prev @ A)
        if cutoff is not None:
            cur << select.valuele(cur, cutoff)
        mask << binary.lt(cur & d)
        if mask.get(dst_id):
            raise Unbounded("Negative cycle detected.")
    path = _reconstruct_path_from_parents(G, p, src_id, dst_id)
    if has_negative_diagonal and path:
        mask.clear()
        mask[G.list_to_ids(path)] = True
        diag = G.get_property("diag", mask=mask.S)
        if diag.nvals > 0:
            raise Unbounded("Negative cycle detected.")
        mask << binary.first(mask & cur)  # mask(cur.S, replace) << mask
        if mask.nvals > 0:
            # Is there a path from any visited node with negative self-loop to target?
            # We could actually stop as soon as any from `path` is visited
            # `to_coo(values=False)[0]` is already the array of ids; it must not be
            # tuple-unpacked again (that only "works" when exactly two ids remain,
            # and then yields a single scalar id instead of all start nodes).
            start_ids = mask.to_coo(values=False)[0]
            q = _bfs_plain(G, target=target, index=start_ids, cutoff=_i)
            if dst_id in q:
                raise Unbounded("Negative cycle detected.")
    return path
277+
278+
167279
def negative_edge_cycle(G):
168280
# TODO: use a heuristic to try to stop early
169281
if G.is_directed():

Diff for: graphblas_algorithms/generators/ego.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ..algorithms.components.connected import _plain_bfs
1+
from ..algorithms.components.connected import _bfs_plain
22
from ..algorithms.shortest_paths.weighted import single_source_bellman_ford_path_length
33

44
__all__ = ["ego_graph"]
@@ -14,7 +14,7 @@ def ego_graph(G, n, radius=1, center=True, undirected=False, is_weighted=False):
1414
if is_weighted:
1515
v = single_source_bellman_ford_path_length(G2, n, cutoff=radius)
1616
else:
17-
v = _plain_bfs(G2, n, cutoff=radius)
17+
v = _bfs_plain(G2, n, cutoff=radius)
1818
if not center:
1919
del v[G._key_to_id[n]]
2020

Diff for: graphblas_algorithms/interface.py

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class Dispatcher:
8383
nxapi.shortest_paths.unweighted.single_target_shortest_path_length
8484
)
8585
all_pairs_shortest_path_length = nxapi.shortest_paths.unweighted.all_pairs_shortest_path_length
86+
bellman_ford_path = nxapi.shortest_paths.weighted.bellman_ford_path
8687
all_pairs_bellman_ford_path_length = (
8788
nxapi.shortest_paths.weighted.all_pairs_bellman_ford_path_length
8889
)

Diff for: graphblas_algorithms/nxapi/shortest_paths/weighted.py

+10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
__all__ = [
88
"all_pairs_bellman_ford_path_length",
9+
"bellman_ford_path",
910
"negative_edge_cycle",
1011
"single_source_bellman_ford_path_length",
1112
]
@@ -55,6 +56,15 @@ def single_source_bellman_ford_path_length(G, source, weight="weight"):
5556
return G.vector_to_nodemap(d)
5657

5758

59+
def bellman_ford_path(G, source, target, weight="weight"):
    """NetworkX-facing entry point for ``algorithms.bellman_ford_path``.

    ``G`` is converted with ``to_graph(G, weight=weight)`` and the path between
    ``source`` and ``target`` is computed on the converted graph.  A ``KeyError``
    raised by the algorithm is re-raised as ``NodeNotFound`` with the same
    arguments and the original exception chained as its cause.
    """
    # TODO: what if weight is a function?
    graph = to_graph(G, weight=weight)
    try:
        path = algorithms.bellman_ford_path(graph, source, target)
    except KeyError as err:
        raise NodeNotFound(*err.args) from err
    return path
66+
67+
5868
def negative_edge_cycle(G, weight="weight", heuristic=True):
5969
# TODO: what if weight is a function?
6070
# TODO: use a heuristic to try to stop early

Diff for: graphblas_algorithms/tests/test_match_nx.py

+19
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212
import sys
1313
from collections import namedtuple
14+
from pathlib import Path
1415

1516
import pytest
1617

@@ -191,3 +192,21 @@ def test_print_dispatched_implemented(nx_names_to_info, gb_names_to_info):
191192
for i, name in enumerate(sorted(fullnames)):
192193
print(i, name)
193194
print("=============================================================================")
195+
196+
197+
def test_algorithms_in_readme(nx_names_to_info, gb_names_to_info):
    """Ensure all algorithms are mentioned in README.md.

    An algorithm counts as mentioned only if its name appears as a whole
    identifier.  A plain substring test would let ``bellman_ford_path`` be
    satisfied by ``all_pairs_bellman_ford_path_length``.

    The check is skipped when README.md cannot be found three directories above
    this file.
    """
    import re

    readme = Path(__file__).parent.parent.parent / "README.md"
    if not readme.exists():
        # Presumably the tests are running without the repository checkout
        # (README.md is looked up relative to this file); nothing to verify then.
        return
    text = readme.read_text(encoding="utf-8")
    implemented = nx_names_to_info.keys() & gb_names_to_info.keys()
    # `\w` includes "_", so the lookarounds reject matches embedded in longer names.
    missing = sorted(
        name
        for name in implemented
        if re.search(rf"(?<!\w){re.escape(name)}(?!\w)", text) is None
    )
    if missing:
        msg = f"Algorithms missing in README.md: {', '.join(missing)}"
        print(msg)
        raise AssertionError(msg)

0 commit comments

Comments
 (0)