diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index b1084fb1..c89cc0bf 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -2269,6 +2269,8 @@ def get_associated_variable_names( if grid := attrs_or_encoding.get("grid", None): coords["grid"] = [grid] + if isinstance(self._obj, Dataset): + coords["coordinates"].extend(sgrid.get_topology_coords(self._obj, grid)) if grid_mapping_attr := attrs_or_encoding.get("grid_mapping", None): # Parse grid mapping variables and their coordinates diff --git a/cf_xarray/sgrid.py b/cf_xarray/sgrid.py index 4773f047..54e12aa8 100644 --- a/cf_xarray/sgrid.py +++ b/cf_xarray/sgrid.py @@ -11,6 +11,31 @@ # "edge3_dimensions", ] +SGRID_COORD_ATTRS = [ + "node_coordinates", + "face_coordinates", + "edge1_coordinates", + "edge2_coordinates", + "volume_coordinates", +] + + +def get_topology_coords(ds, grid_var_name): + """Return coordinate variable names referenced by an SGRID topology variable. + + Reads ``node_coordinates``, ``face_coordinates``, ``edge{1,2}_coordinates``, + and ``volume_coordinates`` from the topology variable's attrs and filters + to names that are actually present in ``ds``. + """ + if grid_var_name not in ds.variables: + return [] + grid_attrs = ds[grid_var_name].attrs + names: list[str] = [] + for attr_name in SGRID_COORD_ATTRS: + if coord_str := grid_attrs.get(attr_name): + names.extend(n for n in coord_str.split() if n in ds.variables) + return names + def parse_axes(ds): import re diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index a188aa39..d1a06bbd 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -2484,6 +2484,43 @@ def test_sgrid(): } +def test_sgrid_includes_topology_coordinates(): + """Variables referenced in node/face/edge/volume_coordinates of the + grid_topology variable should be pulled in by ds.cf[[var]].""" + roms = sgrid_roms.copy() + for pos in ("psi", "rho", "u", "v"): + roms[f"lon_{pos}"] = ((f"xi_{pos}", f"eta_{pos}"), np.zeros((2, 2))) + roms[f"lat_{pos}"] = ((f"xi_{pos}", f"eta_{pos}"), np.zeros((2, 2))) + + expected_coord_vars = { + "lon_psi", + "lat_psi", + "lon_rho", + "lat_rho", + "lon_u", + "lat_u", + "lon_v", + "lat_v", + } + assoc = roms.cf.get_associated_variable_names("u") + assert expected_coord_vars.issubset(set(assoc["coordinates"])) + + subset = roms.cf[["u"]] + assert "grid" in subset.variables + assert expected_coord_vars.issubset(set(subset.variables)) + + # only dim-compatible coords attach to the DataArray form + u_da = roms.cf["u"] + assert {"lon_u", "lat_u"}.issubset(set(u_da.coords)) + + delft = sgrid_delft.copy() + delft["node_lon"] = (("inode", "jnode"), np.zeros((2, 2))) + delft["node_lat"] = (("inode", "jnode"), np.zeros((2, 2))) + delft["foo"] = (("icell", "jcell"), np.ones((2, 2)), {"grid": "grid"}) + delft_subset = delft.cf[["foo"]] + assert {"grid", "node_lon", "node_lat"}.issubset(set(delft_subset.variables)) + + def test_ancillary_variables_extra_dim(): ds = xr.Dataset( { diff --git a/doc/sgrid_ugrid.md b/doc/sgrid_ugrid.md index 313aee1b..ff0417e2 100644 --- a/doc/sgrid_ugrid.md +++ b/doc/sgrid_ugrid.md @@ -61,6 +61,23 @@ only `xi_u`, `eta_u` are listed in the repr even though the attributes on the `g variable `grid` list many more dimension names. ``` +### Coordinate variables + +`cf_xarray` also follows the `node_coordinates`, `face_coordinates`, +`edge1_coordinates`, `edge2_coordinates`, and `volume_coordinates` attributes +on the `grid_topology` variable. When you select a data variable that +references a `grid_topology` via its `grid` attribute, the referenced +coordinate variables are pulled in alongside it: + +```python +ds.cf[["u"]] # includes `grid`, lon_psi/lat_psi, lon_rho/lat_rho, ... +``` + +Only names actually present in the dataset are propagated. For the +`DataArray` form (`ds.cf["u"]`) xarray only attaches coordinates whose +dimensions are compatible with the variable, so e.g. only `lon_u`/`lat_u` +appear as coords on `u`. + ## UGRID ### Topology variable