From b33698aacc4b030cd8345b64590bebcbb4bc21bf Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 28 Aug 2025 18:09:13 -0600 Subject: [PATCH 01/14] Support multiple grid mappings on a DataArray --- cf_xarray/accessor.py | 151 +++++++++++++++++++++++++++---- cf_xarray/datasets.py | 34 +++++++ cf_xarray/tests/test_accessor.py | 61 ++++++++++++- 3 files changed, 224 insertions(+), 22 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 52d525b5..3691f535 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -440,6 +440,32 @@ def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]: return list(results) +def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> list[str]: + """ + Parse a grid_mapping attribute that may contain multiple grid mappings. + + The attribute has the format: "grid_mapping_variable_name: optional_coordinate_names_space_separated" + Multiple sections are separated by colons. + + Examples: + - Single: "spatial_ref" + - Multiple: "spatial_ref: crs_4326: latitude longitude crs_27700: x27700 y27700" + + Returns a list of grid mapping variable names. + """ + # Check if there are colons indicating multiple mappings + if ":" not in grid_mapping_attr: + return [grid_mapping_attr.strip()] + + # Use regex to find grid mapping variable names + # Pattern matches: word at start OR word that comes after some coordinate names and before ":" + # This handles cases like "spatial_ref: crs_4326: latitude longitude crs_27700: x27700 y27700" + pattern = r"(?:^|\s)([a-zA-Z_][a-zA-Z0-9_]*)(?=\s*:)" + matches = re.findall(pattern, grid_mapping_attr) + + return matches if matches else [grid_mapping_attr.strip()] + + def _get_grid_mapping_name(obj: DataArray | Dataset, key: str) -> list[str]: """ Translate from grid mapping name attribute to appropriate variable name. @@ -467,13 +493,17 @@ def _get_grid_mapping_name(obj: DataArray | Dataset, key: str) -> list[str]: for var in variables.values(): attrs_or_encoding = ChainMap(var.attrs, var.encoding) if "grid_mapping" in attrs_or_encoding: - grid_mapping_var_name = attrs_or_encoding["grid_mapping"] - if grid_mapping_var_name not in variables: - raise ValueError( - f"{var} defines non-existing grid_mapping variable {grid_mapping_var_name}." - ) - if key == variables[grid_mapping_var_name].attrs["grid_mapping_name"]: - results.update([grid_mapping_var_name]) + grid_mapping_attr = attrs_or_encoding["grid_mapping"] + # Parse potentially multiple grid mappings + grid_mapping_var_names = _parse_grid_mapping_attribute(grid_mapping_attr) + + for grid_mapping_var_name in grid_mapping_var_names: + if grid_mapping_var_name not in variables: + raise ValueError( + f"{var} defines non-existing grid_mapping variable {grid_mapping_var_name}." + ) + if key == variables[grid_mapping_var_name].attrs["grid_mapping_name"]: + results.update([grid_mapping_var_name]) return list(results) @@ -1943,9 +1973,34 @@ def get_associated_variable_names( if dbounds := self._obj[dim].attrs.get("bounds", None): coords["bounds"].append(dbounds) - for attrname in ["grid", "grid_mapping"]: - if maybe := attrs_or_encoding.get(attrname, None): - coords[attrname] = [maybe] + if grid := attrs_or_encoding.get("grid", None): + coords["grid"] = [grid] + + if grid_mapping_attr := attrs_or_encoding.get("grid_mapping", None): + # Parse grid mapping variables using the same function + grid_mapping_vars = _parse_grid_mapping_attribute(grid_mapping_attr) + coords["grid_mapping"] = grid_mapping_vars + + # Extract coordinate variables using regex + if ":" in grid_mapping_attr: + # Pattern to find coordinate variables: words that come after ":" but before next grid mapping variable + # This captures coordinate variables between grid mapping sections + coord_pattern = r":\s+([^:]+?)(?=\s+[a-zA-Z_][a-zA-Z0-9_]*\s*:|$)" + coord_matches = re.findall(coord_pattern, grid_mapping_attr) + + for coord_section in coord_matches: + # Split each coordinate section and add valid coordinate names + coord_vars = coord_section.split() + # Filter out grid mapping variable names that might have been captured + coord_vars = [ + var + for var in coord_vars + if not ( + var.startswith(("crs_", "spatial_", "proj_")) + and var in grid_mapping_vars + ) + ] + coords["coordinates"].extend(coord_vars) more: Sequence[Hashable] = () if geometry_var := attrs_or_encoding.get("geometry", None): @@ -2899,6 +2954,60 @@ def formula_terms(self) -> dict[str, str]: # numpydoc ignore=SS06 terms[key] = value return terms + @property + def grid_mapping_names(self) -> dict[str, list[str]]: + """ + Mapping the CF grid mapping name to the grid mapping variable name. + + Returns + ------- + dict + Dictionary mapping the CF grid mapping name to the variable name containing + the grid mapping attributes. + + See Also + -------- + DataArray.cf.grid_mapping_name + Dataset.cf.grid_mapping_names + + References + ---------- + https://cfconventions.org/Data/cf-conventions/cf-conventions-1.10/cf-conventions.html#appendix-grid-mappings + + Examples + -------- + >>> from cf_xarray.datasets import hrrrds + >>> hrrrds.foo.cf.grid_mapping_names + {'latitude_longitude': ['crs_4326'], 'lambert_azimuthal_equal_area': ['spatial_ref']} + """ + da = self._obj + attrs_or_encoding = ChainMap(da.attrs, da.encoding) + grid_mapping_attr = attrs_or_encoding.get("grid_mapping", None) + + if not grid_mapping_attr: + return {} + + # Parse potentially multiple grid mappings + grid_mapping_var_names = _parse_grid_mapping_attribute(grid_mapping_attr) + + results = {} + for grid_mapping_var_name in grid_mapping_var_names: + # First check if it's in the DataArray's coords (for multiple grid mappings + # that are coordinates of the DataArray) + if grid_mapping_var_name in da.coords: + grid_mapping_var = da.coords[grid_mapping_var_name] + if "grid_mapping_name" in grid_mapping_var.attrs: + gmn = grid_mapping_var.attrs["grid_mapping_name"] + if gmn not in results: + results[gmn] = [grid_mapping_var_name] + else: + results[gmn].append(grid_mapping_var_name) + # For standalone DataArrays, the grid mapping variables may not be available + # This is a limitation of the xarray data model - when you extract a DataArray + # from a Dataset, it doesn't carry over non-coordinate variables + + return results + @property def grid_mapping_name(self) -> str: """ @@ -2911,6 +3020,7 @@ def grid_mapping_name(self) -> str: See Also -------- + DataArray.cf.grid_mapping_names Dataset.cf.grid_mapping_names Examples @@ -2920,19 +3030,22 @@ def grid_mapping_name(self) -> str: 'rotated_latitude_longitude' """ - da = self._obj + # Use grid_mapping_names under the hood + grid_mapping_names = self.grid_mapping_names - attrs_or_encoding = ChainMap(da.attrs, da.encoding) - grid_mapping = attrs_or_encoding.get("grid_mapping", None) - if not grid_mapping: + if not grid_mapping_names: raise ValueError("No 'grid_mapping' attribute present.") - if grid_mapping not in da._coords: - raise ValueError(f"Grid Mapping variable {grid_mapping} not present.") - - grid_mapping_var = da[grid_mapping] + if len(grid_mapping_names) > 1: + # Get the variable names for error message + all_vars = list(itertools.chain.from_iterable(grid_mapping_names.values())) + raise ValueError( + f"Multiple grid mappings found: {all_vars}. " + "Please use DataArray.cf.grid_mapping_names instead." + ) - return grid_mapping_var.attrs["grid_mapping_name"] + # Return the single grid mapping name + return next(iter(grid_mapping_names.keys())) def __getitem__(self, key: Hashable | Iterable[Hashable]) -> DataArray: """ diff --git a/cf_xarray/datasets.py b/cf_xarray/datasets.py index 32955f25..409a1afe 100644 --- a/cf_xarray/datasets.py +++ b/cf_xarray/datasets.py @@ -1,5 +1,6 @@ import numpy as np import xarray as xr +from pyproj import CRS airds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4), lon=slice(50)) airds.air.attrs["cell_measures"] = "area: cell_area" @@ -750,6 +751,39 @@ def _create_inexact_bounds(): ) +hrrrds = xr.Dataset() +hrrrds["foo"] = ( + ("x", "y"), + np.arange(200).reshape((10, 20)), + { + "grid_mapping": "spatial_ref: crs_4326: latitude longitude crs_27700: x27700 y27700" + }, +) +hrrrds.coords["spatial_ref"] = ((), 0, CRS.from_epsg(3035).to_cf()) +hrrrds.coords["crs_4326"] = ((), 0, CRS.from_epsg(4326).to_cf()) +hrrrds.coords["crs_27700"] = ((), 0, CRS.from_epsg(27700).to_cf()) +hrrrds.coords["latitude"] = ( + ("x", "y"), + np.ones((10, 20)), + {"standard_name": "latitude"}, +) +hrrrds.coords["longitude"] = ( + ("x", "y"), + np.zeros((10, 20)), + {"standard_name": "longitude"}, +) +hrrrds.coords["y27700"] = ( + ("x", "y"), + np.ones((10, 20)), + {"standard_name": "projected_x_coordinate"}, +) +hrrrds.coords["x27700"] = ( + ("x", "y"), + np.zeros((10, 20)), + {"standard_name": "projected_y_coordinate"}, +) + + def point_dataset(): from shapely.geometry import MultiPoint, Point diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index c3f7005a..eac7b0d8 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -29,6 +29,7 @@ flag_indep_uint16, flag_mix, forecast, + hrrrds, mollwds, multiple, popds, @@ -1018,9 +1019,7 @@ def test_grid_mappings(): assert_identical(actual, expected) # not properly propagated if grid mapping variable not in coords - with pytest.raises( - ValueError, match="Grid Mapping variable rotated_pole not present." - ): + with pytest.raises(ValueError, match="No 'grid_mapping' attribute present."): ds.temp.cf.grid_mapping_name # check for https://github.com/xarray-contrib/cf-xarray/issues/448 @@ -1068,6 +1067,62 @@ def test_grid_mappings(): assert "rotated_pole" in ds.coords +def test_multiple_grid_mapping_attribute(): + ds = hrrrds + + # Test Dataset grid_mapping_names + # Now includes British National Grid (EPSG:27700) which has grid_mapping_name + assert ds.cf.grid_mapping_names == { + "latitude_longitude": ["crs_4326"], + "lambert_azimuthal_equal_area": ["spatial_ref"], + "transverse_mercator": ["crs_27700"], + } + + # Test DataArray grid_mapping_names + da = ds.foo + # Now with improved regex parsing, all 3 grid mappings should be detected + assert da.cf.grid_mapping_names == { + "latitude_longitude": ["crs_4326"], + "lambert_azimuthal_equal_area": ["spatial_ref"], + "transverse_mercator": ["crs_27700"], + } + + # Test that grid_mapping_name raises an error with multiple mappings + with pytest.raises( + ValueError, + match="Multiple grid mappings found.*Please use DataArray.cf.grid_mapping_names", + ): + da.cf.grid_mapping_name + + assert "crs_4326" in ds.cf["foo"].coords + assert "spatial_ref" in ds.cf["foo"].coords + assert "crs_27700" in ds.cf["foo"].coords + # Also check that coordinate variables are included + assert "latitude" in ds.cf["foo"].coords + assert "longitude" in ds.cf["foo"].coords + assert "x27700" in ds.cf["foo"].coords + assert "y27700" in ds.cf["foo"].coords + + # Test that accessing grid_mapping with cf indexing raises an error for multiple mappings + with pytest.raises( + KeyError, match="Receive multiple variables for key 'grid_mapping'" + ): + da.cf["grid_mapping"] + + # Test that DataArrays don't support list indexing + with pytest.raises( + KeyError, match="Cannot use an Iterable of keys with DataArrays" + ): + da.cf[["grid_mapping"]] + + # But Dataset should support list indexing and return all grid mappings and coordinates + result = ds.cf[["foo", "grid_mapping"]] + assert "crs_4326" in result.coords + assert "spatial_ref" in result.coords + assert "crs_27700" in result.coords + assert "foo" in result.data_vars + + def test_bad_grid_mapping_attribute(): ds = rotds.copy(deep=False) ds.temp.attrs["grid_mapping"] = "foo" From a35ba2ce4a0354580bcc9ef311a23299b39204d6 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 3 Sep 2025 09:46:22 -0600 Subject: [PATCH 02/14] fix envs --- cf_xarray/datasets.py | 68 +++++++++++++++++---------------- ci/environment-all-min-deps.yml | 1 + ci/environment.yml | 1 + ci/upstream-dev-env.yml | 1 + 4 files changed, 39 insertions(+), 32 deletions(-) diff --git a/cf_xarray/datasets.py b/cf_xarray/datasets.py index 409a1afe..52741b88 100644 --- a/cf_xarray/datasets.py +++ b/cf_xarray/datasets.py @@ -1,6 +1,5 @@ import numpy as np import xarray as xr -from pyproj import CRS airds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4), lon=slice(50)) airds.air.attrs["cell_measures"] = "area: cell_area" @@ -751,37 +750,42 @@ def _create_inexact_bounds(): ) -hrrrds = xr.Dataset() -hrrrds["foo"] = ( - ("x", "y"), - np.arange(200).reshape((10, 20)), - { - "grid_mapping": "spatial_ref: crs_4326: latitude longitude crs_27700: x27700 y27700" - }, -) -hrrrds.coords["spatial_ref"] = ((), 0, CRS.from_epsg(3035).to_cf()) -hrrrds.coords["crs_4326"] = ((), 0, CRS.from_epsg(4326).to_cf()) -hrrrds.coords["crs_27700"] = ((), 0, CRS.from_epsg(27700).to_cf()) -hrrrds.coords["latitude"] = ( - ("x", "y"), - np.ones((10, 20)), - {"standard_name": "latitude"}, -) -hrrrds.coords["longitude"] = ( - ("x", "y"), - np.zeros((10, 20)), - {"standard_name": "longitude"}, -) -hrrrds.coords["y27700"] = ( - ("x", "y"), - np.ones((10, 20)), - {"standard_name": "projected_x_coordinate"}, -) -hrrrds.coords["x27700"] = ( - ("x", "y"), - np.zeros((10, 20)), - {"standard_name": "projected_y_coordinate"}, -) +try: + from pyproj import CRS + + hrrrds = xr.Dataset() + hrrrds["foo"] = ( + ("x", "y"), + np.arange(200).reshape((10, 20)), + { + "grid_mapping": "spatial_ref: crs_4326: latitude longitude crs_27700: x27700 y27700" + }, + ) + hrrrds.coords["spatial_ref"] = ((), 0, CRS.from_epsg(3035).to_cf()) + hrrrds.coords["crs_4326"] = ((), 0, CRS.from_epsg(4326).to_cf()) + hrrrds.coords["crs_27700"] = ((), 0, CRS.from_epsg(27700).to_cf()) + hrrrds.coords["latitude"] = ( + ("x", "y"), + np.ones((10, 20)), + {"standard_name": "latitude"}, + ) + hrrrds.coords["longitude"] = ( + ("x", "y"), + np.zeros((10, 20)), + {"standard_name": "longitude"}, + ) + hrrrds.coords["y27700"] = ( + ("x", "y"), + np.ones((10, 20)), + {"standard_name": "projected_x_coordinate"}, + ) + hrrrds.coords["x27700"] = ( + ("x", "y"), + np.zeros((10, 20)), + {"standard_name": "projected_y_coordinate"}, + ) +except ImportError: + pass def point_dataset(): diff --git a/ci/environment-all-min-deps.yml b/ci/environment-all-min-deps.yml index 8b431eb7..aacb5567 100644 --- a/ci/environment-all-min-deps.yml +++ b/ci/environment-all-min-deps.yml @@ -20,5 +20,6 @@ dependencies: - shapely - xarray==2023.09.0 - pip + - pyproj - pip: - pytest-pretty diff --git a/ci/environment.yml b/ci/environment.yml index bcfb0780..49f10b63 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -19,6 +19,7 @@ dependencies: - scipy - shapely - xarray + - pyproj - pip - pip: - pytest-pretty diff --git a/ci/upstream-dev-env.yml b/ci/upstream-dev-env.yml index 364b071e..0523ed4e 100644 --- a/ci/upstream-dev-env.yml +++ b/ci/upstream-dev-env.yml @@ -13,6 +13,7 @@ dependencies: - pooch - rich - shapely + - pyproj - pip - pip: - pytest-pretty From fa83bcbdbe938cbef47454f565aa4cc7040ce9f5 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 3 Sep 2025 09:49:05 -0600 Subject: [PATCH 03/14] Add pyproj based skipping --- cf_xarray/tests/__init__.py | 1 + cf_xarray/tests/test_accessor.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cf_xarray/tests/__init__.py b/cf_xarray/tests/__init__.py index 8c83df3a..b9f1b8dd 100644 --- a/cf_xarray/tests/__init__.py +++ b/cf_xarray/tests/__init__.py @@ -69,3 +69,4 @@ def LooseVersion(vstring): has_pooch, requires_pooch = _importorskip("pooch") _, requires_rich = _importorskip("rich") has_regex, requires_regex = _importorskip("regex") +has_pyproj, requires_pyproj = _importorskip("pyproj") diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index eac7b0d8..8a78454d 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -29,7 +29,6 @@ flag_indep_uint16, flag_mix, forecast, - hrrrds, mollwds, multiple, popds, @@ -45,6 +44,7 @@ requires_cftime, requires_pint, requires_pooch, + requires_pyproj, requires_regex, requires_rich, requires_scipy, @@ -1067,8 +1067,9 @@ def test_grid_mappings(): assert "rotated_pole" in ds.coords +@requires_pyproj def test_multiple_grid_mapping_attribute(): - ds = hrrrds + from ..datasets import hrrrds as ds # Test Dataset grid_mapping_names # Now includes British National Grid (EPSG:27700) which has grid_mapping_name From 74c4c6118e802a79f3137de01eaf9cba27ca42c5 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 3 Sep 2025 09:57:08 -0600 Subject: [PATCH 04/14] cleanup --- cf_xarray/accessor.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 3691f535..161dc77d 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -4,7 +4,7 @@ import inspect import itertools import re -from collections import ChainMap, namedtuple +from collections import ChainMap, defaultdict, namedtuple from collections.abc import ( Callable, Hashable, @@ -2990,21 +2990,11 @@ def grid_mapping_names(self) -> dict[str, list[str]]: # Parse potentially multiple grid mappings grid_mapping_var_names = _parse_grid_mapping_attribute(grid_mapping_attr) - results = {} - for grid_mapping_var_name in grid_mapping_var_names: - # First check if it's in the DataArray's coords (for multiple grid mappings - # that are coordinates of the DataArray) - if grid_mapping_var_name in da.coords: - grid_mapping_var = da.coords[grid_mapping_var_name] - if "grid_mapping_name" in grid_mapping_var.attrs: - gmn = grid_mapping_var.attrs["grid_mapping_name"] - if gmn not in results: - results[gmn] = [grid_mapping_var_name] - else: - results[gmn].append(grid_mapping_var_name) - # For standalone DataArrays, the grid mapping variables may not be available - # This is a limitation of the xarray data model - when you extract a DataArray - # from a Dataset, it doesn't carry over non-coordinate variables + results = defaultdict(list) + for grid_mapping_var_name in grid_mapping_var_names and set(da.coords): + grid_mapping_var = da.coords[grid_mapping_var_name] + if gmn := grid_mapping_var.attrs.get("grid_mapping_name"): + results[gmn].append(grid_mapping_var_name) return results From c5dfc7acac3efce7266e0e2e15ab8437fd3ae6c3 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 3 Sep 2025 09:59:44 -0600 Subject: [PATCH 05/14] one more --- cf_xarray/accessor.py | 27 ++------------------------- cf_xarray/tests/test_accessor.py | 6 ++++++ 2 files changed, 8 insertions(+), 25 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 161dc77d..d55c909d 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -492,11 +492,9 @@ def _get_grid_mapping_name(obj: DataArray | Dataset, key: str) -> list[str]: results = set() for var in variables.values(): attrs_or_encoding = ChainMap(var.attrs, var.encoding) - if "grid_mapping" in attrs_or_encoding: - grid_mapping_attr = attrs_or_encoding["grid_mapping"] + if grid_mapping_attr := attrs_or_encoding.get("grid_mapping"): # Parse potentially multiple grid mappings grid_mapping_var_names = _parse_grid_mapping_attribute(grid_mapping_attr) - for grid_mapping_var_name in grid_mapping_var_names: if grid_mapping_var_name not in variables: raise ValueError( @@ -1979,28 +1977,7 @@ def get_associated_variable_names( if grid_mapping_attr := attrs_or_encoding.get("grid_mapping", None): # Parse grid mapping variables using the same function grid_mapping_vars = _parse_grid_mapping_attribute(grid_mapping_attr) - coords["grid_mapping"] = grid_mapping_vars - - # Extract coordinate variables using regex - if ":" in grid_mapping_attr: - # Pattern to find coordinate variables: words that come after ":" but before next grid mapping variable - # This captures coordinate variables between grid mapping sections - coord_pattern = r":\s+([^:]+?)(?=\s+[a-zA-Z_][a-zA-Z0-9_]*\s*:|$)" - coord_matches = re.findall(coord_pattern, grid_mapping_attr) - - for coord_section in coord_matches: - # Split each coordinate section and add valid coordinate names - coord_vars = coord_section.split() - # Filter out grid mapping variable names that might have been captured - coord_vars = [ - var - for var in coord_vars - if not ( - var.startswith(("crs_", "spatial_", "proj_")) - and var in grid_mapping_vars - ) - ] - coords["coordinates"].extend(coord_vars) + coords["grid_mapping"] = cast(list[Hashable], grid_mapping_vars) more: Sequence[Hashable] = () if geometry_var := attrs_or_encoding.get("geometry", None): diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 8a78454d..8f025a63 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -1088,6 +1088,12 @@ def test_multiple_grid_mapping_attribute(): "transverse_mercator": ["crs_27700"], } + assert da.cf.get_associated_variable_names()["grid_mapping"] == [ + "latitude_longitude", + "lambert_azimuthal_equal_area", + "transverse_mercator", + ] + # Test that grid_mapping_name raises an error with multiple mappings with pytest.raises( ValueError, From d6ae4b15dce2867a1c6a740de5f53cdf5288b45b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 3 Sep 2025 10:12:27 -0600 Subject: [PATCH 06/14] fixes --- cf_xarray/accessor.py | 73 ++++++++++++++++++++++++-------- cf_xarray/tests/test_accessor.py | 8 ++-- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index d55c909d..ecf49ac4 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -440,7 +440,7 @@ def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]: return list(results) -def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> list[str]: +def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[str]]: """ Parse a grid_mapping attribute that may contain multiple grid mappings. @@ -448,22 +448,52 @@ def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> list[str]: Multiple sections are separated by colons. Examples: - - Single: "spatial_ref" + - Single: "spatial_ref" -> {"spatial_ref": []} - Multiple: "spatial_ref: crs_4326: latitude longitude crs_27700: x27700 y27700" + -> {"spatial_ref": [], "crs_4326": ["latitude", "longitude"], "crs_27700": ["x27700", "y27700"]} - Returns a list of grid mapping variable names. + Returns a dictionary mapping grid mapping variable names to their associated coordinate variables. """ # Check if there are colons indicating multiple mappings if ":" not in grid_mapping_attr: - return [grid_mapping_attr.strip()] - - # Use regex to find grid mapping variable names - # Pattern matches: word at start OR word that comes after some coordinate names and before ":" - # This handles cases like "spatial_ref: crs_4326: latitude longitude crs_27700: x27700 y27700" - pattern = r"(?:^|\s)([a-zA-Z_][a-zA-Z0-9_]*)(?=\s*:)" - matches = re.findall(pattern, grid_mapping_attr) + return {grid_mapping_attr.strip(): []} + + # Use regex to parse the format + # First, find all grid mapping variables (words before colons) + grid_pattern = r"(?:^|\s)([a-zA-Z_][a-zA-Z0-9_]*)(?=\s*:)" + grid_mappings = re.findall(grid_pattern, grid_mapping_attr) + + if not grid_mappings: + return {grid_mapping_attr.strip(): []} + + result = {} + + # Now extract coordinates for each grid mapping + # Split the string to find what comes after each grid mapping variable + for i, gm in enumerate(grid_mappings): + # Pattern to capture everything after this grid mapping until the next one or end + if i < len(grid_mappings) - 1: + next_gm = grid_mappings[i + 1] + # Capture everything between current grid mapping and next one + coord_pattern = ( + rf"{re.escape(gm)}\s*:\s*([^:]*?)(?=\s+{re.escape(next_gm)}\s*:)" + ) + else: + # Last grid mapping - capture everything after it + coord_pattern = rf"{re.escape(gm)}\s*:\s*(.*)$" + + coord_match = re.search(coord_pattern, grid_mapping_attr) + if coord_match: + coord_text = coord_match.group(1).strip() + # Split coordinates and filter out any grid mapping names that might have been captured + coords = coord_text.split() if coord_text else [] + # Filter out the next grid mapping variable if it got captured + coords = [c for c in coords if c not in grid_mappings] + result[gm] = coords + else: + result[gm] = [] - return matches if matches else [grid_mapping_attr.strip()] + return result def _get_grid_mapping_name(obj: DataArray | Dataset, key: str) -> list[str]: @@ -494,8 +524,8 @@ def _get_grid_mapping_name(obj: DataArray | Dataset, key: str) -> list[str]: attrs_or_encoding = ChainMap(var.attrs, var.encoding) if grid_mapping_attr := attrs_or_encoding.get("grid_mapping"): # Parse potentially multiple grid mappings - grid_mapping_var_names = _parse_grid_mapping_attribute(grid_mapping_attr) - for grid_mapping_var_name in grid_mapping_var_names: + grid_mapping_dict = _parse_grid_mapping_attribute(grid_mapping_attr) + for grid_mapping_var_name in grid_mapping_dict.keys(): if grid_mapping_var_name not in variables: raise ValueError( f"{var} defines non-existing grid_mapping variable {grid_mapping_var_name}." @@ -1975,9 +2005,16 @@ def get_associated_variable_names( coords["grid"] = [grid] if grid_mapping_attr := attrs_or_encoding.get("grid_mapping", None): - # Parse grid mapping variables using the same function - grid_mapping_vars = _parse_grid_mapping_attribute(grid_mapping_attr) - coords["grid_mapping"] = cast(list[Hashable], grid_mapping_vars) + # Parse grid mapping variables and their coordinates + grid_mapping_dict = _parse_grid_mapping_attribute(grid_mapping_attr) + coords["grid_mapping"] = cast( + list[Hashable], list(grid_mapping_dict.keys()) + ) + + # Add coordinate variables from the grid mapping + for coord_vars in grid_mapping_dict.values(): + if coord_vars: + coords["coordinates"].extend(coord_vars) more: Sequence[Hashable] = () if geometry_var := attrs_or_encoding.get("geometry", None): @@ -2965,10 +3002,10 @@ def grid_mapping_names(self) -> dict[str, list[str]]: return {} # Parse potentially multiple grid mappings - grid_mapping_var_names = _parse_grid_mapping_attribute(grid_mapping_attr) + grid_mapping_dict = _parse_grid_mapping_attribute(grid_mapping_attr) results = defaultdict(list) - for grid_mapping_var_name in grid_mapping_var_names and set(da.coords): + for grid_mapping_var_name in grid_mapping_dict.keys() & set(da.coords): grid_mapping_var = da.coords[grid_mapping_var_name] if gmn := grid_mapping_var.attrs.get("grid_mapping_name"): results[gmn].append(grid_mapping_var_name) diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 8f025a63..4023476e 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -1088,10 +1088,10 @@ def test_multiple_grid_mapping_attribute(): "transverse_mercator": ["crs_27700"], } - assert da.cf.get_associated_variable_names()["grid_mapping"] == [ - "latitude_longitude", - "lambert_azimuthal_equal_area", - "transverse_mercator", + assert ds.cf.get_associated_variable_names("foo")["grid_mapping"] == [ + "spatial_ref", + "crs_4326", + "crs_27700", ] # Test that grid_mapping_name raises an error with multiple mappings From 5410abf89c1788cd876b54b5dd5bb0ebf3f3f5c3 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 3 Sep 2025 11:16:36 -0600 Subject: [PATCH 07/14] Add GridMapping dataclass --- cf_xarray/__init__.py | 2 +- cf_xarray/accessor.py | 213 ++++++++++++++++++++++++++++++- cf_xarray/tests/test_accessor.py | 110 ++++++++++++++++ ci/doc.yml | 2 + doc/Makefile | 2 +- doc/api.rst | 3 + doc/grid_mappings.md | 101 ++++++++++++--- 7 files changed, 411 insertions(+), 22 deletions(-) diff --git a/cf_xarray/__init__.py b/cf_xarray/__init__.py index 6800c2c7..7f6e06ba 100644 --- a/cf_xarray/__init__.py +++ b/cf_xarray/__init__.py @@ -3,7 +3,7 @@ from . import geometry as geometry from . import sgrid # noqa -from .accessor import CFAccessor # noqa +from .accessor import CFAccessor, GridMapping # noqa from .coding import ( # noqa decode_compress_to_multi_index, encode_multi_index_as_compress, diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index ecf49ac4..c83df7c0 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -13,6 +13,7 @@ MutableMapping, Sequence, ) +from dataclasses import dataclass from datetime import datetime from typing import ( Any, @@ -82,6 +83,61 @@ FlagParam = namedtuple("FlagParam", ["flag_mask", "flag_value"]) + +@dataclass(frozen=True, kw_only=True) +class GridMapping: + """ + Represents a CF grid mapping with its properties and associated coordinate variables. + + Attributes + ---------- + name : str + The CF grid mapping name (e.g., ``'latitude_longitude'``, ``'transverse_mercator'``) + crs : pyproj.CRS + The coordinate reference system object + array : xarray.DataArray + The grid mapping variable as a DataArray containing the CRS parameters + coordinates : tuple[str, ...] + Names of coordinate variables associated with this grid mapping. For grid mappings + that are explicitly listed with coordinates in the grid_mapping attribute + (e.g., ``'spatial_ref: crs_4326: latitude longitude'``), this contains those coordinates. + For grid mappings (e.g. ``spatial_ref``) that don't explicitly specify coordinates, + this falls back to the dimension names of the data variable that references + this grid mapping. + """ + + name: str + crs: Any # really pyproj.CRS + array: xr.DataArray + coordinates: tuple[str, ...] + + def __repr__(self) -> str: + # Try to get EPSG code first, fallback to shorter description + try: + if hasattr(self.crs, "to_epsg") and self.crs.to_epsg(): + crs_repr = f"" + else: + # Use the name if available, otherwise authority:code + crs_name = getattr(self.crs, "name", str(self.crs)[:50] + "...") + crs_repr = f"" + except Exception: + # Fallback to generic representation + crs_repr = "" + + # Short array representation - name and shape + array_repr = f"" + + # Format coordinates nicely + coords_repr = f"({', '.join(repr(c) for c in self.coordinates)})" + + return ( + f"GridMapping(name={self.name!r}, " + f"crs={crs_repr}, " + f"array={array_repr}, " + f"coordinates={coords_repr})" + ) + + #: Classes wrapped by cf_xarray. _WRAPPED_CLASSES = (Resample, GroupBy, Rolling, Coarsen, Weighted) @@ -2406,6 +2462,160 @@ def add_canonical_attributes( return obj + def _create_grid_mapping( + self, + var_name: str, + obj_dataset: Dataset, + grid_mapping_dict: dict[str, list[str]], + ) -> GridMapping: + """ + Create a GridMapping dataclass instance from a grid mapping variable. + + Parameters + ---------- + var_name : str + Name of the grid mapping variable + obj_dataset : Dataset + Dataset containing the grid mapping variable + grid_mapping_dict : dict[str, list[str]] + Dictionary mapping grid mapping variable names to their coordinate variables + + Returns + ------- + GridMapping + GridMapping dataclass instance + + Notes + ----- + Assumes pyproj is available (should be checked by caller). + """ + from pyproj import ( + CRS, # Safe to import since grid_mappings property checks availability + ) + + var = obj_dataset._variables[var_name] + + # Create DataArray from Variable, preserving the name + # Use reset_coords(drop=True) to avoid coordinate conflicts + if var_name in obj_dataset.coords: + da = obj_dataset.coords[var_name].reset_coords(drop=True) + else: + da = obj_dataset[var_name].reset_coords(drop=True) + + # Get the CF grid mapping name from the variable's attributes + cf_name = var.attrs.get("grid_mapping_name", var_name) + + # Create CRS from the grid mapping variable + try: + crs = CRS.from_cf(var.attrs) + except Exception: + # If CRS creation fails, use None + crs = None + + # Get associated coordinate variables, fallback to dimension names + coordinates = grid_mapping_dict.get(var_name, []) + if not coordinates: + # For DataArrays, find the data variable that references this grid mapping + for _data_var_name, data_var in obj_dataset.data_vars.items(): + if "grid_mapping" in data_var.attrs: + gm_attr = data_var.attrs["grid_mapping"] + if var_name in gm_attr: + coordinates = list(data_var.dims) + break + + return GridMapping( + name=cf_name, crs=crs, array=da, coordinates=tuple(coordinates) + ) + + @property + def grid_mappings(self) -> tuple[GridMapping, ...]: + """ + Return a tuple of GridMapping objects for all grid mappings in this object. + + For DataArrays, the order in the tuple matches the order that grid mappings appear + in the grid_mapping attribute string. + + Parameters + ---------- + None + + Returns + ------- + tuple[GridMapping, ...] + Tuple of GridMapping dataclass instances, each containing: + - name: CF grid mapping name + - crs: pyproj.CRS object + - array: xarray.DataArray containing the grid mapping variable + - coordinates: tuple of coordinate variable names + + Raises + ------ + ImportError + If pyproj is not available. This property requires pyproj for CRS creation. + + Examples + -------- + >>> ds.cf.grid_mappings + (GridMapping(name='latitude_longitude', crs=, ...),) + + Notes + ----- + This property requires pyproj to be installed for creating CRS objects from + CF grid mapping parameters. Install with: ``conda install pyproj`` or + ``pip install pyproj``. + """ + # Check pyproj availability upfront + try: + import pyproj # noqa: F401 + except ImportError: + raise ImportError( + "pyproj is required for .cf.grid_mappings property. " + "Install with: conda install pyproj or pip install pyproj" + ) from None + # For DataArrays, preserve order from grid_mapping attribute + if isinstance(self._obj, DataArray) and "grid_mapping" in self._obj.attrs: + grid_mapping_dict = _parse_grid_mapping_attribute( + self._obj.attrs["grid_mapping"] + ) + # Get grid mappings in the order they appear in the string + ordered_var_names = list(grid_mapping_dict.keys()) + else: + # For Datasets, look for grid_mapping attributes in data variables + grid_mapping_dict = {} + ordered_var_names = [] + + # Search all data variables for grid_mapping attributes + for _var_name, var in self._obj.data_vars.items(): + if "grid_mapping" in var.attrs: + parsed = _parse_grid_mapping_attribute(var.attrs["grid_mapping"]) + grid_mapping_dict.update(parsed) + # Add variables in order they appear in this grid_mapping string + for gm_var in parsed.keys(): + if gm_var not in ordered_var_names: + ordered_var_names.append(gm_var) + + # If no grid_mapping attributes found in data vars, try grid_mapping_names property + if not ordered_var_names and hasattr(self, "grid_mapping_names"): + grid_mapping_names = self.grid_mapping_names + for var_names in grid_mapping_names.values(): + ordered_var_names.extend(var_names) + + if not ordered_var_names: + return () + + grid_mappings = [] + obj_dataset = self._maybe_to_dataset() + + for var_name in ordered_var_names: + if var_name not in obj_dataset._variables: + continue + + grid_mappings.append( + self._create_grid_mapping(var_name, obj_dataset, grid_mapping_dict) + ) + + return tuple(grid_mappings) + @xr.register_dataset_accessor("cf") class CFDatasetAccessor(CFAccessor): @@ -3009,8 +3219,7 @@ def grid_mapping_names(self) -> dict[str, list[str]]: grid_mapping_var = da.coords[grid_mapping_var_name] if gmn := grid_mapping_var.attrs.get("grid_mapping_name"): results[gmn].append(grid_mapping_var_name) - - return results + return dict(results) @property def grid_mapping_name(self) -> str: diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 4023476e..b70e4667 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -1130,6 +1130,116 @@ def test_multiple_grid_mapping_attribute(): assert "foo" in result.data_vars +@requires_pyproj +def test_grid_mappings_property(): + """Test the .cf.grid_mappings property on both Dataset and DataArray.""" + from ..datasets import hrrrds + + ds = hrrrds + + # Test Dataset + grid_mappings = ds.cf.grid_mappings + assert len(grid_mappings) == 3 + + # Check that all expected grid mapping names are present + gm_names = {gm.name for gm in grid_mappings} + expected_names = { + "latitude_longitude", + "lambert_azimuthal_equal_area", + "transverse_mercator", + } + assert gm_names == expected_names + + # Test specific properties of each grid mapping + for gm in grid_mappings: + assert gm.crs is not None # Should have pyproj CRS + assert isinstance(gm.array, xr.DataArray) # Should be DataArray, not Variable + assert isinstance(gm.coordinates, tuple) + assert gm.array.name is not None # DataArray should preserve name + + # Check specific coordinate associations + if gm.name == "latitude_longitude": + assert gm.coordinates == ("latitude", "longitude") + elif gm.name == "transverse_mercator": + assert gm.coordinates == ("x27700", "y27700") + elif gm.name == "lambert_azimuthal_equal_area": + assert gm.coordinates == ( + "x", + "y", + ) # Falls back to data variable dimensions + + # Test DataArray + da = ds.foo + da_grid_mappings = da.cf.grid_mappings + assert len(da_grid_mappings) == 3 + + # DataArray should have the same grid mappings as Dataset + da_names = {gm.name for gm in da_grid_mappings} + assert da_names == expected_names + + # Check that coordinates are populated for DataArray too + for gm in da_grid_mappings: + assert len(gm.coordinates) > 0 # Should never be empty now + if gm.name == "lambert_azimuthal_equal_area": + assert gm.coordinates == ("x", "y") + + +@requires_pyproj +def test_grid_mappings_coordinates_attribute(): + """Test that coordinates attribute is always populated correctly for DataArray grid mappings.""" + from ..datasets import hrrrds + + ds = hrrrds + + # Focus on DataArray access + da = ds.foo + grid_mappings = da.cf.grid_mappings + assert len(grid_mappings) == 3 + + # Verify order preservation for DataArray (should match grid_mapping attribute order) + expected_order = [ + "lambert_azimuthal_equal_area", + "latitude_longitude", + "transverse_mercator", + ] + actual_order = [gm.name for gm in grid_mappings] + assert actual_order == expected_order, ( + f"DataArray order {actual_order} doesn't match expected {expected_order}" + ) + + for gm in grid_mappings: + # Coordinates should never be empty + assert len(gm.coordinates) > 0, ( + f"Grid mapping '{gm.name}' has empty coordinates" + ) + + # All coordinates should be strings + assert all(isinstance(coord, str) for coord in gm.coordinates), ( + f"Grid mapping '{gm.name}' has non-string coordinates: {gm.coordinates}" + ) + + # Test specific expected coordinates for each grid mapping + if gm.name == "latitude_longitude": + # Explicitly listed in grid_mapping attribute: "crs_4326: latitude longitude" + assert gm.coordinates == ("latitude", "longitude"), ( + f"Expected ('latitude', 'longitude'), got {gm.coordinates}" + ) + elif gm.name == "transverse_mercator": + # Explicitly listed in grid_mapping attribute: "crs_27700: x27700 y27700" + assert gm.coordinates == ("x27700", "y27700"), ( + f"Expected ('x27700', 'y27700'), got {gm.coordinates}" + ) + elif gm.name == "lambert_azimuthal_equal_area": + # Not explicitly listed, should fallback to DataArray dimensions + assert gm.coordinates == ("x", "y"), ( + f"Expected ('x', 'y') from DataArray dimensions, got {gm.coordinates}" + ) + # Verify these are actually the DataArray's dimensions + assert gm.coordinates == da.dims, ( + f"Fallback coordinates {gm.coordinates} don't match DataArray dims {da.dims}" + ) + + def test_bad_grid_mapping_attribute(): ds = rotds.copy(deep=False) ds.temp.attrs["grid_mapping"] = "foo" diff --git a/ci/doc.yml b/ci/doc.yml index 22640931..601e048e 100644 --- a/ci/doc.yml +++ b/ci/doc.yml @@ -10,6 +10,7 @@ dependencies: - xarray - sphinx - sphinx-copybutton + - sphinx-autobuild - numpydoc - sphinx-autosummary-accessors - ipython @@ -22,5 +23,6 @@ dependencies: - shapely - furo>=2024 - myst-nb + - pyproj - pip: - -e ../ diff --git a/doc/Makefile b/doc/Makefile index 89413fb5..637852b4 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -4,7 +4,7 @@ # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build +SPHINXBUILD ?= sphinx-autobuild SOURCEDIR = . BUILDDIR = _build diff --git a/doc/api.rst b/doc/api.rst index bc3d689b..bd91bd4b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -16,6 +16,7 @@ Top-level API set_options encode_multi_index_as_compress decode_compress_to_multi_index + GridMapping Geometries ---------- @@ -56,6 +57,7 @@ Attributes DataArray.cf.coordinates DataArray.cf.formula_terms DataArray.cf.grid_mapping_name + DataArray.cf.grid_mappings DataArray.cf.is_flag_variable DataArray.cf.standard_names DataArray.cf.plot @@ -115,6 +117,7 @@ Attributes Dataset.cf.coordinates Dataset.cf.formula_terms Dataset.cf.grid_mapping_names + Dataset.cf.grid_mappings Dataset.cf.geometries Dataset.cf.standard_names diff --git a/doc/grid_mappings.md b/doc/grid_mappings.md index 15e48379..c253b5af 100644 --- a/doc/grid_mappings.md +++ b/doc/grid_mappings.md @@ -17,6 +17,8 @@ kernelspec: 1. [CF conventions on grid mappings and projections](https://cfconventions.org/Data/cf-conventions/cf-conventions-1.10/cf-conventions.html#grid-mappings-and-projections) 1. {py:attr}`Dataset.cf.grid_mapping_names` 1. {py:attr}`DataArray.cf.grid_mapping_name` +1. {py:attr}`DataArray.cf.grid_mappings` +1. {py:attr}`Dataset.cf.grid_mappings` ``` `cf_xarray` understands the concept of coordinate projections using the [grid_mapping](https://cfconventions.org/Data/cf-conventions/cf-conventions-1.10/cf-conventions.html#grid-mappings-and-projections) attribute convention. For example, the dataset might contain two sets of coordinates: @@ -28,16 +30,22 @@ Due to the projection, those real coordinates are probably 2D data variables. Th ## Extracting grid mapping info -### Dataset - To access `grid_mapping` attributes, consider this example: ```{code-cell} +%xmode minimal + +import numpy as np + +np.set_printoptions(threshold=10, edgeitems=2) + from cf_xarray.datasets import rotds rotds ``` +### Dataset + The related grid mappings can be discovered using `Dataset.cf.grid_mapping_names` which maps a ["grid mapping name"](http://cfconventions.org/cf-conventions/cf-conventions.html#appendix-grid-mappings) to the appropriate variable name: @@ -73,33 +81,90 @@ And to get the grid mapping variable da.cf["grid_mapping"] ``` +## Multiple grid mappings + +A somewhat niche feature in CF is the ability to specify multiple grid mappings e.g. coordinate locations in multiple coordinate systems associated with the same data. + +```{code-cell} +from cf_xarray.datasets import hrrrds + +hrrrds.cf["foo"] +``` + +This dataset has 3 grid mappings associated with the "foo" array. + +### Dataset + +Use `grid_mapping_names` (note the plural) to get them all: + +```{code-cell} +hrrrds.cf.grid_mapping_names +``` + +### DataArray + +Use `grid_mapping_names` (note the plural) to get them all: + +```{code-cell} +hrrrds.foo.cf.grid_mapping_names +``` + +Simply asking for one will raise an error + +```{code-cell} +--- +tags: [raises-exception] +--- +hrrrds.foo.cf.grid_mapping_name +``` + +All grid mapping variables and coordinate variables are in the "associated variable names" for that array: + +```{code-cell} +hrrrds.cf.get_associated_variable_names("foo") +``` + +## GridMapping structure + +A richer data structure {py:class}`~cf_xarray.GridMapping` is also available: + +```{code-cell} +hrrrds.foo.cf.grid_mappings +``` + +The ordering of this tuple matches that of the `grid_mapping` attribute + +```{code-cell} +grid = next(iter(hrrrds.foo.cf.grid_mappings)) +grid +``` + +The grid mapping DataArray can be extracted + +```{code-cell} +grid.array +``` + ## Use `grid_mapping` in projections +### pyproj + The grid mapping information use very useful in projections, e.g., for plotting. [pyproj](https://pyproj4.github.io/pyproj/stable/api/crs/crs.html#pyproj.crs.CRS.from_cf) understands CF conventions right away, e.g. -```python +```{code-cell} from pyproj import CRS CRS.from_cf(rotds.cf["grid_mapping"].attrs) ``` -gives you more details on the projection: +Alternatively use the {py:class}`~cf_xarray.GridMapping` object +```{code-cell} +grid = next(iter(rotds.cf.grid_mappings)) # there's only one, so extract it +grid.crs ``` - -Name: undefined -Axis Info [ellipsoidal]: -- lon[east]: Longitude (degree) -- lat[north]: Latitude (degree) -Area of Use: -- undefined -Coordinate Operation: -- name: Pole rotation (netCDF CF convention) -- method: Pole rotation (netCDF CF convention) -Datum: World Geodetic System 1984 -- Ellipsoid: WGS 84 -- Prime Meridian: Greenwich -``` + +### cartopy For use in cartopy, there is some more overhead due to [this issue](https://github.com/SciTools/cartopy/issues/2099). So you should select the right cartopy CRS and just feed in the grid mapping info: From 196225bd6dd94d3f9ff810ecf9d2b49ed40a1637 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 3 Sep 2025 11:20:52 -0600 Subject: [PATCH 08/14] Update doc/Makefile --- doc/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/Makefile b/doc/Makefile index 637852b4..89413fb5 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -4,7 +4,7 @@ # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-autobuild +SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build From 5aa287807526d21df064670fb669912530407c05 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 3 Sep 2025 11:30:00 -0600 Subject: [PATCH 09/14] fix types --- cf_xarray/accessor.py | 140 +++++++++++++++++++++--------------------- 1 file changed, 69 insertions(+), 71 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index c83df7c0..7e7c2ad3 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -97,7 +97,7 @@ class GridMapping: The coordinate reference system object array : xarray.DataArray The grid mapping variable as a DataArray containing the CRS parameters - coordinates : tuple[str, ...] + coordinates : tuple[Hashable, ...] Names of coordinate variables associated with this grid mapping. For grid mappings that are explicitly listed with coordinates in the grid_mapping attribute (e.g., ``'spatial_ref: crs_4326: latitude longitude'``), this contains those coordinates. @@ -109,7 +109,7 @@ class GridMapping: name: str crs: Any # really pyproj.CRS array: xr.DataArray - coordinates: tuple[str, ...] + coordinates: tuple[Hashable, ...] def __repr__(self) -> str: # Try to get EPSG code first, fallback to shorter description @@ -496,7 +496,7 @@ def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]: return list(results) -def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[str]]: +def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[Hashable]]: """ Parse a grid_mapping attribute that may contain multiple grid mappings. @@ -522,7 +522,7 @@ def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[str] if not grid_mappings: return {grid_mapping_attr.strip(): []} - result = {} + result: dict[str, list[Hashable]] = {} # Now extract coordinates for each grid mapping # Split the string to find what comes after each grid mapping variable @@ -545,13 +545,76 @@ def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[str] coords = coord_text.split() if coord_text else [] # Filter out the next grid mapping variable if it got captured coords = [c for c in coords if c not in grid_mappings] - result[gm] = coords + result[gm] = coords # type: ignore[assignment] else: result[gm] = [] return result +def _create_grid_mapping( + var_name: str, + obj_dataset: Dataset, + grid_mapping_dict: dict[str, list[Hashable]], +) -> GridMapping: + """ + Create a GridMapping dataclass instance from a grid mapping variable. + + Parameters + ---------- + var_name : str + Name of the grid mapping variable + obj_dataset : Dataset + Dataset containing the grid mapping variable + grid_mapping_dict : dict[str, list[Hashable]] + Dictionary mapping grid mapping variable names to their coordinate variables + + Returns + ------- + GridMapping + GridMapping dataclass instance + + Notes + ----- + Assumes pyproj is available (should be checked by caller). + """ + from pyproj import ( + CRS, # Safe to import since grid_mappings property checks availability + ) + + var = obj_dataset._variables[var_name] + + # Create DataArray from Variable, preserving the name + # Use reset_coords(drop=True) to avoid coordinate conflicts + if var_name in obj_dataset.coords: + da = obj_dataset.coords[var_name].reset_coords(drop=True) + else: + da = obj_dataset[var_name].reset_coords(drop=True) + + # Get the CF grid mapping name from the variable's attributes + cf_name = var.attrs.get("grid_mapping_name", var_name) + + # Create CRS from the grid mapping variable + try: + crs = CRS.from_cf(var.attrs) + except Exception: + # If CRS creation fails, use None + crs = None + + # Get associated coordinate variables, fallback to dimension names + coordinates: list[Hashable] = grid_mapping_dict.get(var_name, []) + if not coordinates: + # For DataArrays, find the data variable that references this grid mapping + for _data_var_name, data_var in obj_dataset.data_vars.items(): + if "grid_mapping" in data_var.attrs: + gm_attr = data_var.attrs["grid_mapping"] + if var_name in gm_attr: + coordinates = list(data_var.dims) + break + + return GridMapping(name=cf_name, crs=crs, array=da, coordinates=tuple(coordinates)) + + def _get_grid_mapping_name(obj: DataArray | Dataset, key: str) -> list[str]: """ Translate from grid mapping name attribute to appropriate variable name. @@ -2462,71 +2525,6 @@ def add_canonical_attributes( return obj - def _create_grid_mapping( - self, - var_name: str, - obj_dataset: Dataset, - grid_mapping_dict: dict[str, list[str]], - ) -> GridMapping: - """ - Create a GridMapping dataclass instance from a grid mapping variable. - - Parameters - ---------- - var_name : str - Name of the grid mapping variable - obj_dataset : Dataset - Dataset containing the grid mapping variable - grid_mapping_dict : dict[str, list[str]] - Dictionary mapping grid mapping variable names to their coordinate variables - - Returns - ------- - GridMapping - GridMapping dataclass instance - - Notes - ----- - Assumes pyproj is available (should be checked by caller). - """ - from pyproj import ( - CRS, # Safe to import since grid_mappings property checks availability - ) - - var = obj_dataset._variables[var_name] - - # Create DataArray from Variable, preserving the name - # Use reset_coords(drop=True) to avoid coordinate conflicts - if var_name in obj_dataset.coords: - da = obj_dataset.coords[var_name].reset_coords(drop=True) - else: - da = obj_dataset[var_name].reset_coords(drop=True) - - # Get the CF grid mapping name from the variable's attributes - cf_name = var.attrs.get("grid_mapping_name", var_name) - - # Create CRS from the grid mapping variable - try: - crs = CRS.from_cf(var.attrs) - except Exception: - # If CRS creation fails, use None - crs = None - - # Get associated coordinate variables, fallback to dimension names - coordinates = grid_mapping_dict.get(var_name, []) - if not coordinates: - # For DataArrays, find the data variable that references this grid mapping - for _data_var_name, data_var in obj_dataset.data_vars.items(): - if "grid_mapping" in data_var.attrs: - gm_attr = data_var.attrs["grid_mapping"] - if var_name in gm_attr: - coordinates = list(data_var.dims) - break - - return GridMapping( - name=cf_name, crs=crs, array=da, coordinates=tuple(coordinates) - ) - @property def grid_mappings(self) -> tuple[GridMapping, ...]: """ @@ -2611,7 +2609,7 @@ def grid_mappings(self) -> tuple[GridMapping, ...]: continue grid_mappings.append( - self._create_grid_mapping(var_name, obj_dataset, grid_mapping_dict) + _create_grid_mapping(var_name, obj_dataset, grid_mapping_dict) ) return tuple(grid_mappings) From 87ccd423746f745cc2bb68d4f9a68552ad165fee Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 3 Sep 2025 16:55:50 -0600 Subject: [PATCH 10/14] Fix accidental data load --- cf_xarray/accessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 7e7c2ad3..964df6f9 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -399,7 +399,7 @@ def _get_axis_coord(obj: DataArray | Dataset, key: str) -> list[str]: results.update((coord,)) if criterion == "units": # deal with pint-backed objects - units = getattr(var.data, "units", None) + units = getattr(var.variable._data, "units", None) if units in expected: results.update((coord,)) From 6696f042e2461eba84dfe8fb64b5f0bc9d75f0fc Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 4 Sep 2025 09:11:05 -0600 Subject: [PATCH 11/14] cleanup Co-authored-by: Claude --- cf_xarray/accessor.py | 47 +++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 964df6f9..9e97595d 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -554,7 +554,7 @@ def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[Hash def _create_grid_mapping( var_name: str, - obj_dataset: Dataset, + ds: Dataset, grid_mapping_dict: dict[str, list[Hashable]], ) -> GridMapping: """ @@ -576,41 +576,40 @@ def _create_grid_mapping( Notes ----- - Assumes pyproj is available (should be checked by caller). + Assumes pyproj is available. """ - from pyproj import ( - CRS, # Safe to import since grid_mappings property checks availability - ) + import pyproj - var = obj_dataset._variables[var_name] + var = ds._variables[var_name] # Create DataArray from Variable, preserving the name - # Use reset_coords(drop=True) to avoid coordinate conflicts - if var_name in obj_dataset.coords: - da = obj_dataset.coords[var_name].reset_coords(drop=True) - else: - da = obj_dataset[var_name].reset_coords(drop=True) + da = xr.DataArray(ds._variables[var_name], name=var_name) # Get the CF grid mapping name from the variable's attributes cf_name = var.attrs.get("grid_mapping_name", var_name) # Create CRS from the grid mapping variable - try: - crs = CRS.from_cf(var.attrs) - except Exception: - # If CRS creation fails, use None - crs = None + crs = pyproj.CRS.from_cf(var.attrs) # Get associated coordinate variables, fallback to dimension names coordinates: list[Hashable] = grid_mapping_dict.get(var_name, []) - if not coordinates: - # For DataArrays, find the data variable that references this grid mapping - for _data_var_name, data_var in obj_dataset.data_vars.items(): - if "grid_mapping" in data_var.attrs: - gm_attr = data_var.attrs["grid_mapping"] - if var_name in gm_attr: - coordinates = list(data_var.dims) - break + # """ + # In order to make use of a grid mapping to directly calculate latitude and longitude values + # it is necessary to associate the coordinate variables with the independent variables of the mapping. + # This is done by assigning a standard_name to the coordinate variable. + # The appropriate values of the standard_name depend on the grid mapping and are given in Appendix F, Grid Mappings. + # """ + if not coordinates and len(grid_mapping_dict) == 1: + if crs.to_cf().get("grid_mapping_name") == "rotated_latitude_longitude": + xname, yname = "grid_longitude", "grid_latitude" + elif crs.is_geographic: + xname, yname = "longitude", "latitude" + elif crs.is_projected: + xname, yname = "projection_x_coordinate", "projection_y_coordinate" + + x = apply_mapper(_get_with_standard_name, ds, xname, error=False, default=[[]]) + y = apply_mapper(_get_with_standard_name, ds, yname, error=False, default=[[]]) + coordinates = tuple(itertools.chain(x, y)) return GridMapping(name=cf_name, crs=crs, array=da, coordinates=tuple(coordinates)) From 8f0f492efea15036794af1b00da3bc440e1175c5 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 4 Sep 2025 09:11:05 -0600 Subject: [PATCH 12/14] cleanup Co-authored-by: Claude --- cf_xarray/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cf_xarray/datasets.py b/cf_xarray/datasets.py index 52741b88..b38d3ff1 100644 --- a/cf_xarray/datasets.py +++ b/cf_xarray/datasets.py @@ -758,7 +758,7 @@ def _create_inexact_bounds(): ("x", "y"), np.arange(200).reshape((10, 20)), { - "grid_mapping": "spatial_ref: crs_4326: latitude longitude crs_27700: x27700 y27700" + "grid_mapping": "spatial_ref: x y crs_4326: latitude longitude crs_27700: x27700 y27700" }, ) hrrrds.coords["spatial_ref"] = ((), 0, CRS.from_epsg(3035).to_cf()) From ea81c5e696cfb5fe3cb01af372729c0cf62f66ff Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 4 Sep 2025 09:28:38 -0600 Subject: [PATCH 13/14] fix types --- cf_xarray/accessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 9e97595d..21feb6ff 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -609,7 +609,7 @@ def _create_grid_mapping( x = apply_mapper(_get_with_standard_name, ds, xname, error=False, default=[[]]) y = apply_mapper(_get_with_standard_name, ds, yname, error=False, default=[[]]) - coordinates = tuple(itertools.chain(x, y)) + coordinates = list(itertools.chain(x, y)) return GridMapping(name=cf_name, crs=crs, array=da, coordinates=tuple(coordinates)) From da3a55750f991ca68059b664fd6a0b3c44e7c148 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 4 Sep 2025 09:40:51 -0600 Subject: [PATCH 14/14] more test --- cf_xarray/tests/test_accessor.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index b70e4667..3b65fbc6 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -987,6 +987,7 @@ def test_get_bounds_dim_name() -> None: ds.cf.get_bounds_dim_name("longitude") +@requires_pyproj def test_grid_mappings(): ds = rotds.copy(deep=False) @@ -1036,6 +1037,26 @@ def test_grid_mappings(): # grid_mapping_name assert ds.cf["temp"].cf.grid_mapping_name == "rotated_latitude_longitude" + # Test .grid_mappings property with single grid mapping + grid_mappings = ds.cf.grid_mappings + assert len(grid_mappings) == 1 + gm = grid_mappings[0] + assert gm.name == "rotated_latitude_longitude" + assert gm.array.name == "rotated_pole" + assert gm.array.shape == () # scalar variable + assert isinstance(gm.coordinates, tuple) + # Should have rlon and rlat detected from standard names + assert gm.coordinates == ("rlon", "rlat") + + # Test .grid_mappings property on DataArray with propagated coords + da_grid_mappings = ds.cf["temp"].cf.grid_mappings + assert len(da_grid_mappings) == 1 + da_gm = da_grid_mappings[0] + assert da_gm.name == "rotated_latitude_longitude" + assert da_gm.array.name == "rotated_pole" + # Should also detect rlon and rlat from standard names + assert da_gm.coordinates == ("rlon", "rlat") + # what if there are really 2 grid mappins? ds["temp2"] = ds.temp ds["temp2"].attrs["grid_mapping"] = "rotated_pole2" @@ -1240,6 +1261,7 @@ def test_grid_mappings_coordinates_attribute(): ) +@requires_pyproj def test_bad_grid_mapping_attribute(): ds = rotds.copy(deep=False) ds.temp.attrs["grid_mapping"] = "foo" @@ -1255,6 +1277,10 @@ def test_bad_grid_mapping_attribute(): with pytest.warns(UserWarning): ds.cf.get_associated_variable_names("temp", error=False) + # Test .grid_mappings property with bad grid mapping - should return empty tuple + grid_mappings = ds.cf.grid_mappings + assert grid_mappings == () # No valid grid mappings since 'foo' doesn't exist + def test_docstring() -> None: assert "One of ('X'" in airds.cf.groupby.__doc__