Skip to content

Commit

Permalink
refactor element_as_gdf internals (#724)
Browse files Browse the repository at this point in the history
* refactor element_as_gdf internal constructors

* deprecate geom_col kwarg in elem_as_gdf

* update element_as_gdf() stacklevel
  • Loading branch information
jGaboardi committed May 24, 2023
1 parent b1feadd commit c25f662
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 38 deletions.
33 changes: 28 additions & 5 deletions spaghetti/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3039,7 +3039,7 @@ def element_as_gdf(
snapped=False,
routes=None,
id_col="id",
geom_col="geometry",
geom_col=None,
):
"""Return a ``geopandas.GeoDataFrame`` of network elements. This can be
(a) the vertices of a network; (b) the arcs of a network; (c) both the
Expand Down Expand Up @@ -3068,8 +3068,9 @@ def element_as_gdf(
``geopandas.GeoDataFrame`` column name for IDs. Default is ``"id"``.
When extracting routes this creates an (origin, destination) tuple.
geom_col : str
``geopandas.GeoDataFrame`` column name for geometry. Default is
``"geometry"``.
Deprecated and will be removed in the next minor release.
``geopandas.GeoDataFrame`` column name for geometry. Default is ``None``,
in which case the geometry column keeps the default ``"geometry"`` name.
Raises
------
Expand Down Expand Up @@ -3142,9 +3143,22 @@ def element_as_gdf(
"""

# see GH#722
if geom_col:
dep_msg = (
"The ``geom_col`` keyword argument is deprecated and will "
"be dropped in the next minor release of pysal/spaghetti (1.8.0) "
"in favor of the default 'geometry' name. Users can rename "
"the geometry column following processing, if desired."
)
warnings.warn(dep_msg, FutureWarning, stacklevel=2)

# shortest path routes between observations
if routes:
paths = util._routes_as_gdf(routes, id_col)
# see GH#722
if geom_col:
paths.rename_geometry(geom_col, inplace=True)
return paths

# need vertices place holder to create network segment LineStrings
Expand All @@ -3162,21 +3176,30 @@ def element_as_gdf(
pp_name,
snapped,
id_col=id_col,
geom_col=geom_col,
)

# return points geodataframe if arcs not specified or
# if extracting `PointPattern` points
if not arcs or pp_name:
# see GH#722
if geom_col:
points.rename_geometry(geom_col, inplace=True)
return points

# arcs
arcs = util._arcs_as_gdf(net, points, id_col=id_col, geom_col=geom_col)
arcs = util._arcs_as_gdf(net, points, id_col=id_col)

if vertices_for_arcs:
# see GH#722
if geom_col:
arcs.rename_geometry(geom_col, inplace=True)
return arcs

else:
# see GH#722
if geom_col:
points.rename_geometry(geom_col, inplace=True)
arcs.rename_geometry(geom_col, inplace=True)
return points, arcs


Expand Down
29 changes: 28 additions & 1 deletion spaghetti/tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,12 +464,31 @@ def test_element_as_gdf(self):
observed_arc_wkt = observed_arc.wkt
assert observed_arc_wkt == known_arc_wkt

# See GH#722
with pytest.warns(
FutureWarning, match="The ``geom_col`` keyword argument is deprecated"
):
geom_col = "MY_GEOM"
vertices, arcs = spaghetti.element_as_gdf(
self.ntw_shp, vertices=True, arcs=True, geom_col=geom_col
)
assert vertices.geometry.name == geom_col
assert arcs.geometry.name == geom_col

# extract only arcs
arcs = spaghetti.element_as_gdf(self.ntw_shp, arcs=True)
observed_arc = arcs.loc[(arcs["id"] == (0, 1)), "geometry"].squeeze()
observed_arc_wkt = observed_arc.wkt
assert observed_arc_wkt == known_arc_wkt

# See GH#722
with pytest.warns(
FutureWarning, match="The ``geom_col`` keyword argument is deprecated"
):
geom_col = "MY_GEOM"
arcs = spaghetti.element_as_gdf(self.ntw_shp, arcs=True, geom_col=geom_col)
assert arcs.geometry.name == geom_col

# extract symmetric routes
known_length, bounds, h, v = 2.6, (0, 0, 3, 3), 2, 2
lattice = spaghetti.regular_lattice(bounds, h, nv=v, exterior=False)
Expand All @@ -493,9 +512,17 @@ def test_element_as_gdf(self):
_, tree = ntw.allneighbordistances(points1, points2, gen_tree=True)
paths = ntw.shortest_paths(tree, points1, pp_dest=points2)
paths_gdf = spaghetti.element_as_gdf(ntw, routes=paths)
observed_origins = paths_gdf["O"].nunique()
observed_origins = paths_gdf["id"].map(lambda x: x[0]).nunique()
assert observed_origins == known_origins

# See GH#722
with pytest.warns(
FutureWarning, match="The ``geom_col`` keyword argument is deprecated"
):
geom_col = "MY_GEOM"
paths_gdf = spaghetti.element_as_gdf(ntw, routes=paths, geom_col=geom_col)
assert paths_gdf.geometry.name == geom_col

def test_regular_lattice(self):
# 4x4 regular lattice with the exterior
known = [P00, P10]
Expand Down
47 changes: 15 additions & 32 deletions spaghetti/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

try:
import geopandas
import pandas
import shapely
from shapely.geometry import LineString
except ImportError:
Expand Down Expand Up @@ -654,15 +655,7 @@ def build_chains(space_h, space_v, exterior, bounds, h=True):


@requires("geopandas", "shapely")
def _points_as_gdf(
net,
vertices,
vertices_for_arcs,
pp_name,
snapped,
id_col=None,
geom_col=None, # noqa ARG001
):
def _points_as_gdf(net, vertices, vertices_for_arcs, pp_name, snapped, id_col=None):
"""Internal function for returning a point ``geopandas.GeoDataFrame``
called from within ``spaghetti.element_as_gdf()``.
Expand Down Expand Up @@ -744,7 +737,7 @@ def _points_as_gdf(


@requires("geopandas", "shapely")
def _arcs_as_gdf(net, points, id_col=None, geom_col=None):
def _arcs_as_gdf(net, points, id_col=None):
"""Internal function for returning an arc ``geopandas.GeoDataFrame``
called from within ``spaghetti.element_as_gdf()``.
Expand All @@ -763,19 +756,17 @@ def _arcs_as_gdf(net, points, id_col=None, geom_col=None):
"""

# arcs
arcs = {}

# iterate over network arcs
for vtx1_id, vtx2_id in net.arcs:
# extract vertices comprising the network arc
vtx1 = points.loc[(points[id_col] == vtx1_id), geom_col].squeeze()
vtx2 = points.loc[(points[id_col] == vtx2_id), geom_col].squeeze()
# create a LineString for the network arc
arcs[(vtx1_id, vtx2_id)] = LineString((vtx1, vtx2))
def _line_coords(loc):
return (
(points.loc[loc[0]].geometry.x, points.loc[loc[0]].geometry.y),
(points.loc[loc[1]].geometry.x, points.loc[loc[1]].geometry.y),
)

# instantiate GeoDataFrame
arcs = geopandas.GeoDataFrame(sorted(arcs.items()), columns=[id_col, geom_col])
arcs = pandas.DataFrame(zip(sorted(net.arcs)), columns=[id_col])
arcs = arcs.set_geometry(
shapely.linestrings(arcs[id_col].map(_line_coords).values.tolist())
)

# additional columns
if hasattr(net, "network_component_labels"):
Expand Down Expand Up @@ -806,17 +797,9 @@ def _routes_as_gdf(paths, id_col):
"""

# isolate the origins, destinations, and geometries
origs = [o for (o, d), g in paths]
dests = [d for (o, d), g in paths]
geoms = [LineString(g.vertices) for (o, d), g in paths]

# instantiate as a geodataframe
paths = geopandas.GeoDataFrame(geometry=geoms)
paths["O"] = origs
paths["D"] = dests

if id_col:
paths[id_col] = paths.apply(lambda x: (x["O"], x["D"]), axis=1)
paths = dict(paths)
ids, geoms = zip(paths.keys()), [LineString(g.vertices) for g in paths.values()]
paths = geopandas.GeoDataFrame(ids, columns=[id_col], geometry=geoms)

return paths

0 comments on commit c25f662

Please sign in to comment.