Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Graph IO to classic weights file formats #698

Merged
merged 8 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/312-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ dependencies:
# testing
- codecov
- matplotlib
- tobler
- h3-py
- pytest
- pytest-cov
Expand Down Expand Up @@ -39,4 +38,5 @@ dependencies:
- xarray
- git+https://github.com/geopandas/geopandas.git@main
- git+https://github.com/shapely/shapely.git@main
- git+https://github.com/pysal/tobler.git@main
- pulp
2 changes: 1 addition & 1 deletion libpysal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import contextlib
from importlib.metadata import PackageNotFoundError, version

from . import cg, examples, io, weights
from . import cg, examples, graph, io, weights

with contextlib.suppress(PackageNotFoundError):
__version__ = version("libpysal")
2 changes: 1 addition & 1 deletion libpysal/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .base import Graph, read_parquet # noqa
from .base import Graph, read_parquet, read_gal, read_gwt # noqa
8 changes: 6 additions & 2 deletions libpysal/graph/_contiguity.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,12 @@ def _vertex_set_intersection(geoms, rook=True, ids=None, by_perimeter=False):
nexus_names = {ids[ix] for ix in nexus}
for geom_ix in nexus:
gid = ids[geom_ix]
graph[gid] |= nexus_names
graph[gid].remove(gid)
graph[gid].update(nexus_names)

for idx in ids:
graph[idx].remove(idx)

# return graph
martinfleis marked this conversation as resolved.
Show resolved Hide resolved

heads, tails, weights = _neighbor_dict_to_edges(graph)

Expand Down
18 changes: 10 additions & 8 deletions libpysal/graph/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,15 @@ def _neighbor_dict_to_edges(neighbors, weights=None):
that any self-loops have a weight of zero.
"""
idxs = pd.Series(neighbors).explode()
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
"Downcasting object dtype arrays on .fillna, .ffill, .bfill ",
FutureWarning,
)
idxs = idxs.fillna(pd.Series(idxs.index, index=idxs.index)) # self-loops
isolates = idxs.isna()
if isolates.any():
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
"Downcasting object dtype arrays on .fillna, .ffill, .bfill ",
FutureWarning,
)
idxs = idxs.fillna(pd.Series(idxs.index, index=idxs.index)) # self-loops
heads, tails = idxs.index.values, idxs.values
tails = tails.astype(heads.dtype)
if weights is not None:
Expand All @@ -130,7 +132,7 @@ def _neighbor_dict_to_edges(neighbors, weights=None):
data_array = pd.to_numeric(data_array)
else:
data_array = np.ones(idxs.shape[0], dtype=int)
data_array[heads == tails] = 0
data_array[isolates.values] = 0
return heads, tails, data_array


Expand Down
80 changes: 79 additions & 1 deletion libpysal/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from ._indices import _build_from_h3
from ._kernel import _distance_band, _kernel
from ._matching import _spatial_matching
from ._parquet import _read_parquet, _to_parquet
from ._plotting import _explore_graph, _plot
from ._set_ops import SetOpsMixin
from ._spatial_lag import _lag_spatial
Expand All @@ -30,6 +29,9 @@
_resolve_islands,
_sparse_to_arrays,
)
from .io._gal import _read_gal, _to_gal
from .io._gwt import _read_gwt, _to_gwt
from .io._parquet import _read_parquet, _to_parquet

ALLOWED_TRANSFORMATIONS = ("O", "B", "R", "D", "V", "C")

Expand All @@ -41,6 +43,13 @@
Levi John Wolf (levi.john.wolf@gmail.com)
"""

__all__ = [
"Graph",
"read_parquet",
"read_gal",
"read_gwt",
]


class Graph(SetOpsMixin):
"""Graph class encoding spatial weights matrices
Expand Down Expand Up @@ -1644,6 +1653,38 @@ def to_parquet(self, path, **kwargs):
"""
_to_parquet(self, path, **kwargs)

def to_gal(self, path):
    """Write this Graph to disk in the GAL format.

    Only the neighbor structure is stored: the file holds the number of
    observations followed by, for every focal observation, its id with
    the neighbor count and the list of neighbor ids. Weights are not
    written.

    Parameters
    ----------
    path : str
        location of the GAL file to create

    See also
    --------
    read_gal
    """
    _to_gal(graph_obj=self, path=path)

def to_gwt(self, path):
    """Write this Graph to disk in the GWT format.

    The file holds a single header line followed by one
    ``focal neighbor weight`` row per edge. Spaces inside ids are
    replaced with underscores and floating-point weights are formatted
    with seven decimal places.

    Parameters
    ----------
    path : str
        location of the GWT file to create

    See also
    --------
    read_gwt
    """
    _to_gwt(graph_obj=self, path=path)

def to_networkx(self):
"""Convert Graph to a ``networkx`` graph.

Expand Down Expand Up @@ -1999,3 +2040,40 @@ def read_parquet(path, **kwargs):
"""
adjacency, transformation = _read_parquet(path, **kwargs)
return Graph(adjacency, transformation, is_sorted=True)


def read_gal(path):
    """Build a Graph from a GAL file

    Ids are read as text and then converted to integers when every id in
    the file allows it; otherwise all ids are kept as strings.

    Parameters
    ----------
    path : str
        path to the GAL file

    Returns
    -------
    Graph
        Graph constructed from the neighbors stored in the file
    """
    return Graph.from_dicts(_read_gal(path))


def read_gwt(path):
    """Build a Graph from a GWT file

    Parameters
    ----------
    path : str
        path to the GWT file

    Returns
    -------
    Graph
        Graph constructed from the focal, neighbor and weight columns
        stored in the file
    """
    # _read_gwt returns the (focal, neighbor, weight) arrays in that order
    return Graph.from_arrays(*_read_gwt(path))
64 changes: 64 additions & 0 deletions libpysal/graph/io/_gal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import contextlib


def _read_gal(path):
    """Parse a GAL file into a neighbors dictionary

    Parameters
    ----------
    path : str
        path to GAL file

    Returns
    -------
    dict
        mapping of each id to the list of its neighbor ids; keys and
        values are integers when every id in the file can be cast to
        ``int`` and strings otherwise
    """
    with open(path) as gal:
        # the first line is either ``n`` alone or several tokens, in which
        # case the number of observations is the second token
        tokens = gal.readline().strip().split()
        n = int(tokens[0])
        if len(tokens) > 1:
            # NOTE: the coverage report attached to this change flags this
            # branch as not covered by tests
            n = int(tokens[1])

        raw = {}
        for _ in range(n):
            # each record spans two lines: "<id> <count>" and the neighbor ids;
            # the declared count is read but not used
            node, _count = gal.readline().strip().split()
            raw[node] = gal.readline().strip().split()

    # cast to ints to ensure loss-less roundtrip of integer node ids; a single
    # non-integer id anywhere keeps every id as a string
    try:
        return {int(node): [int(nb) for nb in nbs] for node, nbs in raw.items()}
    except ValueError:
        return raw


def _to_gal(graph_obj, path):
    """Write a Graph object to a GAL file

    The file starts with the number of observations. Every focal
    observation then takes two lines: ``<id> <number of neighbors>`` and
    the space-separated neighbor ids (an empty line for isolates).
    Weights are not written.

    GAL is whitespace-delimited, so spaces inside ids are replaced with
    underscores, matching what the GWT writer does. Without that, an id
    containing a space yields a record line that cannot be split back
    into ``<id> <count>``.

    Parameters
    ----------
    graph_obj : Graph
        Graph object
    path : str
        path to GAL file
    """
    # loop-invariant, so look it up once instead of once per observation
    isolates = graph_obj.isolates
    grouper = graph_obj._adjacency.groupby(level=0, sort=False)

    with open(path, "w") as file:
        file.write(f"{graph_obj.n}\n")

        for ix, chunk in grouper:
            if ix in isolates:
                # isolates are written with a zero count and an empty line
                neighbors = []
            else:
                neighbor_ids = chunk.index.get_level_values("neighbor").astype(str)
                neighbors = [nb.replace(" ", "_") for nb in neighbor_ids]

            focal = str(ix).replace(" ", "_")
            file.write(f"{focal} {len(neighbors)}\n")
            file.write(" ".join(neighbors) + "\n")
38 changes: 38 additions & 0 deletions libpysal/graph/io/_gwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pandas as pd


def _read_gwt(path):
    """
    Parse a GWT file into edge arrays

    The first line of the file (the header) is skipped; every remaining
    line is read as whitespace-separated columns, of which the first
    three are used.

    Parameters
    ----------
    path : str
        path to GWT file

    Returns
    -------
    tuple
        focal, neighbor, weight arrays
    """
    table = pd.read_csv(path, sep=r"\s+", skiprows=1, header=None)
    # with header=None the columns are labelled by position
    focal, neighbor, weight = (table[position].values for position in range(3))
    return focal, neighbor, weight


def _to_gwt(graph_obj, path):
    """
    Write a Graph object to a GWT file

    The file consists of a header line ``0 <n> Unknown Unknown`` followed
    by one space-separated ``focal neighbor weight`` row per adjacency
    entry. Spaces inside ids are replaced with underscores and
    floating-point weights are formatted with seven decimal places.

    Parameters
    ----------
    graph_obj : Graph
        Graph object
    path : str
        path to GWT file
    """
    table = graph_obj._adjacency.reset_index()
    # GWT is whitespace-delimited, so ids must not contain spaces
    for column in ("focal", "neighbor"):
        table[column] = table[column].astype(str).str.replace(" ", "_")

    header = f"0 {graph_obj.n} Unknown Unknown\n"
    with open(path, "w") as gwt:
        gwt.write(header)

    # the header handle is closed at this point; pandas reopens the file in
    # append mode and adds the rows after the header
    table.to_csv(
        path,
        sep=" ",
        header=False,
        index=False,
        mode="a",
        float_format="%.7f",
    )
File renamed without changes.
26 changes: 26 additions & 0 deletions libpysal/graph/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,32 @@ def test_parquet(self):
g_pandas = graph.read_parquet(path)
assert self.g_str == g_pandas

def test_gal(self):
    # Roundtrip both the integer-indexed and the string-indexed fixture
    # through a GAL file. Each graph is transformed to binary first,
    # presumably because GAL does not store weights.
    cases = ((self.g_int, "g_int.gal"), (self.g_str, "g_str.gal"))
    with tempfile.TemporaryDirectory() as tmpdir:
        for source, filename in cases:
            path = os.path.join(tmpdir, filename)
            expected = source.transform("b")
            expected.to_gal(path)
            assert expected == graph.read_gal(path)

def test_gwt(self):
    # Roundtrip both the integer-indexed and the string-indexed fixture
    # through a GWT file and require the result to equal the original.
    cases = ((self.g_int, "g_int.gwt"), (self.g_str, "g_str.gwt"))
    with tempfile.TemporaryDirectory() as tmpdir:
        for expected, filename in cases:
            path = os.path.join(tmpdir, filename)
            expected.to_gwt(path)
            assert expected == graph.read_gwt(path)

def test_getitem(self):
expected = pd.Series(
[1, 0.5, 0.5],
Expand Down