From 2e2632341283c3c903876be1093889b92b955e7e Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 23 Jun 2020 16:56:58 -0600 Subject: [PATCH 1/3] Update pre-commit --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26bf4803..da65e9d7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,15 +8,15 @@ repos: files: .+\.py$ # https://github.com/python/black#version-control-integration - repo: https://github.com/python/black - rev: stable + rev: 19.10b0 hooks: - id: black - repo: https://gitlab.com/pycqa/flake8 - rev: 3.7.9 + rev: 3.8.3 hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.761 # Must match ci/requirements/*.yml + rev: v0.781 # Must match ci/requirements/*.yml hooks: - id: mypy # run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 From 89dba500fb4649f9935876b91c0d9db52784e0f6 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 23 Jun 2020 16:59:03 -0600 Subject: [PATCH 2/3] Rework getitem for standard_name support --- cf_xarray/accessor.py | 161 ++++++++++++++++++++++++++----- cf_xarray/tests/test_accessor.py | 80 +++++++-------- 2 files changed, 174 insertions(+), 67 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 70231511..ffa6df6f 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -1,9 +1,20 @@ import functools import inspect import itertools +import textwrap from collections import ChainMap from contextlib import suppress -from typing import Callable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union +from typing import ( + Callable, + Hashable, + List, + Mapping, + MutableMapping, + Optional, + Set, + Tuple, + Union, +) import xarray as xr from xarray import DataArray, Dataset @@ -106,6 +117,11 @@ ] +def _strip_none_list(lst: List[Optional[str]]) -> List[str]: + """ The mappers can return [None]. Strip that when necessary. Keeps mypy happy.""" + return [item for item in lst if item != [None]] # type: ignore + + def _get_axis_coord_single( var: Union[xr.DataArray, xr.Dataset], key: str, @@ -176,8 +192,8 @@ def _get_axis_coord( results: Set = set() for coord in search_in: for criterion, valid_values in coordinate_criteria.items(): - if key in valid_values: # type: ignore - expected = valid_values[key] # type: ignore + if key in valid_values: + expected = valid_values[key] if var.coords[coord].attrs.get(criterion, None) in expected: results.update((coord,)) @@ -246,6 +262,31 @@ def _get_measure( } +def _filter_by_standard_names(ds: xr.Dataset, name: Union[str, List[str]]) -> List[str]: + """ returns a list of variable names with standard names matching name. """ + if isinstance(name, str): + name = [name] + + varnames = [] + counts = dict.fromkeys(name, 0) + for vname, var in ds.variables.items(): + stdname = var.attrs.get("standard_name", None) + if stdname in name: + varnames.append(str(vname)) + counts[stdname] += 1 + + return varnames + + +def _get_list_standard_names(obj: xr.Dataset) -> List[str]: + """ Returns a sorted list of standard names in Dataset. """ + names = [] + for k, v in obj.variables.items(): + if "standard_name" in v.attrs: + names.append(v.attrs["standard_name"]) + return sorted(names) + + def _getattr( obj: Union[DataArray, Dataset], attr: str, @@ -503,6 +544,16 @@ def _describe(self): text += f"\t{measure}: unsupported\n" else: text += f"\t{measure}: {_get_measure(self._obj, measure, error=False, default=None)}\n" + + text += "\nStandard Names:\n" + if isinstance(self._obj, xr.DataArray): + text += "\tunsupported\n" + else: + stdnames = _get_list_standard_names(self._obj) + text += "\t" + text += "\n".join( + textwrap.wrap(f"{stdnames!r}", 70, break_long_words=False) + ) return text def describe(self): @@ -529,32 +580,96 @@ def get_valid_keys(self) -> Set[str]: ] if measures: varnames.append(*measures) + + if not isinstance(self._obj, xr.DataArray): + varnames.extend(_get_list_standard_names(self._obj)) return set(varnames) + def __getitem__(self, key: Union[str, List[str]]): + + kind = str(type(self._obj).__name__) + scalar_key = isinstance(key, str) + if scalar_key: + key = (key,) # type: ignore + + varnames: List[Hashable] = [] + coords: List[Hashable] = [] + successful = dict.fromkeys(key, False) + for k in key: + if k in _AXIS_NAMES + _COORD_NAMES: + names = _get_axis_coord(self._obj, k) + successful[k] = bool(names) + varnames.extend(_strip_none_list(names)) + coords.extend(_strip_none_list(names)) + elif k in _CELL_MEASURES: + if isinstance(self._obj, xr.Dataset): + raise NotImplementedError( + "Invalid key {k!r}. Cell measures not implemented for Dataset yet." + ) + else: + measure = _get_measure(self._obj, k) + successful[k] = bool(measure) + if measure: + varnames.append(measure) + elif not isinstance(self._obj, xr.DataArray): + stdnames = _filter_by_standard_names(self._obj, k) + successful[k] = bool(stdnames) + varnames.extend(stdnames) + coords.extend(list(set(stdnames).intersection(set(self._obj.coords)))) + + # these are not special names but could be variable names in underlying object + # we allow this so that we can return variables with appropriate CF auxiliary variables + varnames.extend([k for k, v in successful.items() if not v]) + assert len(varnames) > 0 + + try: + # TODO: make this a get_auxiliary_variables function + # make sure to set coordinate variables referred to in "coordinates" attribute + for name in varnames: + attrs = self._obj[name].attrs + if "coordinates" in attrs: + coords.extend(attrs.get("coordinates").split(" ")) + + if "cell_measures" in attrs: + measures = [ + _get_measure(self._obj[name], measure) + for measure in _CELL_MEASURES + if measure in attrs["cell_measures"] + ] + coords.extend(_strip_none_list(measures)) + + varnames.extend(coords) + if isinstance(self._obj, xr.DataArray): + ds = self._obj._to_temp_dataset() + else: + ds = self._obj + ds = ds.reset_coords()[varnames] + if isinstance(self._obj, DataArray): + if scalar_key and len(ds.variables) == 1: + # single dimension coordinates + return ds[list(ds.variables.keys())[0]].squeeze(drop=True) + elif scalar_key and len(ds.coords) > 1: + raise NotImplementedError( + "Not sure what to return when given scalar key for DataArray and it has multiple values. " + "Please open an issue." + ) + elif not scalar_key: + return ds.set_coords(coords) + else: + return ds.set_coords(coords) + + except KeyError: + raise KeyError( + f"{kind}.cf does not understand the key {k!r}. " + f"Use {kind}.cf.describe() to see a list of key names that can be interpreted." + ) + @xr.register_dataset_accessor("cf") class CFDatasetAccessor(CFAccessor): - def __getitem__(self, key): - if key in _AXIS_NAMES + _COORD_NAMES: - varnames = _get_axis_coord(self._obj, key) - return self._obj.reset_coords()[varnames].set_coords(varnames) - elif key in _CELL_MEASURES: - raise NotImplementedError("measures not implemented for Dataset yet.") - else: - raise KeyError( - f"Dataset.cf does not understand the key {key!r}. Use Dataset.cf.describe() to see a list of key names that can be interpreted." - ) + pass @xr.register_dataarray_accessor("cf") class CFDataArrayAccessor(CFAccessor): - def __getitem__(self, key): - if key in _AXIS_NAMES + _COORD_NAMES: - varname = _get_axis_coord_single(self._obj, key) - return self._obj[varname].reset_coords(drop=True) - elif key in _CELL_MEASURES: - return self._obj[_get_measure(self._obj, key)] - else: - raise KeyError( - f"DataArray.cf does not understand the key {key!r}. Use DataArray.cf.describe() to see a list of key names that can be interpreted." - ) + pass diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 4aa52724..47d66e60 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -8,33 +8,43 @@ import cf_xarray # noqa from . import raise_if_dask_computes +from .datasets import airds, ds_no_attrs, popds mpl.use("Agg") -ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4), lon=slice(50)) -ds.air.attrs["cell_measures"] = "area: cell_area" -ds.coords["cell_area"] = ( - xr.DataArray(np.cos(ds.lat * np.pi / 180)) * xr.ones_like(ds.lon) * 105e3 * 110e3 -) -ds_no_attrs = ds.copy(deep=True) -for variable in ds_no_attrs.variables: - ds_no_attrs[variable].attrs = {} -datasets = [ds, ds.chunk({"lat": 5})] -dataarrays = [ds.air, ds.air.chunk({"lat": 5})] +ds = airds +datasets = [airds, airds.chunk({"lat": 5})] +dataarrays = [airds.air, airds.air.chunk({"lat": 5})] objects = datasets + dataarrays def test_describe(): - actual = ds.cf._describe() + actual = airds.cf._describe() expected = ( "Axes:\n\tX: ['lon']\n\tY: ['lat']\n\tZ: [None]\n\tT: ['time']\n" "\nCoordinates:\n\tlongitude: ['lon']\n\tlatitude: ['lat']" "\n\tvertical: [None]\n\ttime: ['time']\n" "\nCell Measures:\n\tarea: unsupported\n\tvolume: unsupported\n" + "\nStandard Names:\n\t['air_temperature', 'latitude', 'longitude', 'time']" ) assert actual == expected +def test_getitem_standard_name(): + actual = airds.cf["air_temperature"] + expected = airds[["air"]] + assert_identical(actual, expected) + + ds = airds.copy(deep=True) + ds["air2"] = ds.air + actual = ds.cf["air_temperature"] + expected = ds[["air", "air2"]] + assert_identical(actual, expected) + + with pytest.raises(KeyError): + ds.air.cf["air_temperature"] + + @pytest.mark.parametrize("obj", objects) @pytest.mark.parametrize( "attr, xrkwargs, cfkwargs", @@ -128,7 +138,7 @@ def test_kwargs_expand_key_to_multiple_keys(): @pytest.mark.parametrize( "obj, expected", [ - (ds, set(("latitude", "longitude", "time", "X", "Y", "T"))), + (ds, set(("latitude", "longitude", "time", "X", "Y", "T", "air_temperature"))), (ds.air, set(("latitude", "longitude", "time", "X", "Y", "T", "area"))), (ds_no_attrs.air, set()), ], @@ -146,6 +156,19 @@ def test_args_methods(obj): assert_identical(expected, actual) +def test_dataarray_getitem(): + + air = airds.air + air.name = None + + assert_identical(air.cf["longitude"], air["lon"]) + assert_identical(air.cf[["longitude"]], air["lon"].reset_coords()) + assert_identical( + air.cf[["longitude", "latitude"]], + air.to_dataset(name="air").drop_vars("cell_area")[["lon", "lat"]], + ) + + @pytest.mark.parametrize("obj", dataarrays) def test_dataarray_plot(obj): @@ -211,38 +234,7 @@ def test_getitem_errors(obj,): def test_getitem_uses_coordinates(): # POP-like dataset - ds = xr.Dataset() - ds.coords["TLONG"] = ( - ("nlat", "nlon"), - np.ones((20, 30)), - {"axis": "X", "units": "degrees_east"}, - ) - ds.coords["TLAT"] = ( - ("nlat", "nlon"), - 2 * np.ones((20, 30)), - {"axis": "Y", "units": "degrees_north"}, - ) - ds.coords["ULONG"] = ( - ("nlat", "nlon"), - 0.5 * np.ones((20, 30)), - {"axis": "X", "units": "degrees_east"}, - ) - ds.coords["ULAT"] = ( - ("nlat", "nlon"), - 2.5 * np.ones((20, 30)), - {"axis": "Y", "units": "degrees_north"}, - ) - ds["UVEL"] = ( - ("nlat", "nlon"), - np.ones((20, 30)) * 15, - {"coordinates": "ULONG ULAT"}, - ) - ds["TEMP"] = ( - ("nlat", "nlon"), - np.ones((20, 30)) * 15, - {"coordinates": "TLONG TLAT"}, - ) - + ds = popds assert_identical( ds.cf["X"], ds.reset_coords()[["ULONG", "TLONG"]].set_coords(["ULONG", "TLONG"]) ) From a742dfeec03d8d40550f14e520bf480ba35491dc Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 23 Jun 2020 17:04:20 -0600 Subject: [PATCH 3/3] Add datasets.py --- cf_xarray/tests/datasets.py | 49 +++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 cf_xarray/tests/datasets.py diff --git a/cf_xarray/tests/datasets.py b/cf_xarray/tests/datasets.py new file mode 100644 index 00000000..c2588c8a --- /dev/null +++ b/cf_xarray/tests/datasets.py @@ -0,0 +1,49 @@ +import numpy as np +import xarray as xr + +airds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4), lon=slice(50)) +airds.air.attrs["cell_measures"] = "area: cell_area" +airds.air.attrs["standard_name"] = "air_temperature" +airds.coords["cell_area"] = ( + xr.DataArray(np.cos(airds.lat * np.pi / 180)) + * xr.ones_like(airds.lon) + * 105e3 + * 110e3 +) + +ds_no_attrs = airds.copy(deep=True) +for variable in ds_no_attrs.variables: + ds_no_attrs[variable].attrs = {} + + +popds = xr.Dataset() +popds.coords["TLONG"] = ( + ("nlat", "nlon"), + np.ones((20, 30)), + {"axis": "X", "units": "degrees_east"}, +) +popds.coords["TLAT"] = ( + ("nlat", "nlon"), + 2 * np.ones((20, 30)), + {"axis": "Y", "units": "degrees_north"}, +) +popds.coords["ULONG"] = ( + ("nlat", "nlon"), + 0.5 * np.ones((20, 30)), + {"axis": "X", "units": "degrees_east"}, +) +popds.coords["ULAT"] = ( + ("nlat", "nlon"), + 2.5 * np.ones((20, 30)), + {"axis": "Y", "units": "degrees_north"}, +) +popds["UVEL"] = ( + ("nlat", "nlon"), + np.ones((20, 30)) * 15, + {"coordinates": "ULONG ULAT", "standard_name": "sea_water_x_velocity"}, +) +popds["TEMP"] = ( + ("nlat", "nlon"), + np.ones((20, 30)) * 15, + {"coordinates": "TLONG TLAT", "standard_name": "sea_water_potential_temperature"}, +)