diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 24fa8cb7..70231511 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -2,6 +2,7 @@ import inspect import itertools from collections import ChainMap +from contextlib import suppress from typing import Callable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union import xarray as xr @@ -226,6 +227,13 @@ def _get_measure( else: return default measures = dict(zip(strings[slice(0, None, 2)], strings[slice(1, None, 2)])) + if key not in measures: + if error: + raise KeyError( + f"Cell measure {key!r} not found. Please use .cf.describe() to see a list of key names that can be interpreted." + ) + else: + return default return measures[key] @@ -244,6 +252,7 @@ def _getattr( accessor: "CFAccessor", key_mappers: Mapping[str, Mapper], wrap_classes: bool = False, + extra_decorator: Callable = None, ): """ Common getattr functionality. @@ -261,13 +270,17 @@ def _getattr( Only True for the high level CFAccessor. Facilitates code reuse for _CFWrappedClass and _CFWrapppedPlotMethods For both of those, wrap_classes is False. + extra_decorator: Callable (optional) + An extra decorator, if necessary. This is used by _CFPlotMethods to set default + kwargs based on CF attributes. """ - func = getattr(obj, attr) + func: Callable = getattr(obj, attr) @functools.wraps(func) def wrapper(*args, **kwargs): arguments = accessor._process_signature(func, args, kwargs, key_mappers) - result = func(**arguments) + final_func = extra_decorator(func) if extra_decorator else func + result = final_func(**arguments) if wrap_classes and isinstance(result, _WRAPPED_CLASSES): result = _CFWrappedClass(result, accessor) @@ -312,6 +325,40 @@ def __init__(self, obj, accessor): self.accessor = accessor self._keys = ("x", "y", "hue", "col", "row") + def _plot_decorator(self, func): + """ + This decorator is used to set kwargs on plotting functions. + """ + valid_keys = self.accessor.get_valid_keys() + + @functools.wraps(func) + def _plot_wrapper(*args, **kwargs): + if "x" in kwargs: + if kwargs["x"] in valid_keys: + xvar = self.accessor[kwargs["x"]] + else: + xvar = self._obj[kwargs["x"]] + if "positive" in xvar.attrs: + if xvar.attrs["positive"] == "down": + kwargs.setdefault("xincrease", False) + else: + kwargs.setdefault("xincrease", True) + + if "y" in kwargs: + if kwargs["y"] in valid_keys: + yvar = self.accessor[kwargs["y"]] + else: + yvar = self._obj[kwargs["y"]] + if "positive" in yvar.attrs: + if yvar.attrs["positive"] == "down": + kwargs.setdefault("yincrease", False) + else: + kwargs.setdefault("yincrease", True) + + return func(*args, **kwargs) + + return _plot_wrapper + def __call__(self, *args, **kwargs): plot = _getattr( obj=self._obj, @@ -319,7 +366,7 @@ def __call__(self, *args, **kwargs): accessor=self.accessor, key_mappers=dict.fromkeys(self._keys, _get_axis_coord_single), ) - return plot(*args, **kwargs) + return self._plot_decorator(plot)(*args, **kwargs) def __getattr__(self, attr): return _getattr( @@ -327,6 +374,9 @@ def __getattr__(self, attr): attr=attr, accessor=self.accessor, key_mappers=dict.fromkeys(self._keys, _get_axis_coord_single), + # TODO: "extra_decorator" is more complex than I would like it to be. + # Not sure if there is a better way though + extra_decorator=self._plot_decorator, ) @@ -458,6 +508,29 @@ def _describe(self): def describe(self): print(self._describe()) + def get_valid_keys(self) -> Set[str]: + """ + Returns valid keys for .cf[] + + Returns + ------- + Set of valid key names that can be used with __getitem__ or .cf[key]. + """ + varnames = [ + key + for key in _AXIS_NAMES + _COORD_NAMES + if _get_axis_coord(self._obj, key, error=False, default=None) != [None] + ] + with suppress(NotImplementedError): + measures = [ + key + for key in _CELL_MEASURES + if _get_measure(self._obj, key, error=False) is not None + ] + if measures: + varnames.append(*measures) + return set(varnames) + @xr.register_dataset_accessor("cf") class CFDatasetAccessor(CFAccessor): diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 5f9b3b30..4aa52724 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -15,6 +15,10 @@ ds.coords["cell_area"] = ( xr.DataArray(np.cos(ds.lat * np.pi / 180)) * xr.ones_like(ds.lon) * 105e3 * 110e3 ) +ds_no_attrs = ds.copy(deep=True) +for variable in ds_no_attrs.variables: + ds_no_attrs[variable].attrs = {} + datasets = [ds, ds.chunk({"lat": 5})] dataarrays = [ds.air, ds.air.chunk({"lat": 5})] objects = datasets + dataarrays @@ -121,6 +125,19 @@ def test_kwargs_expand_key_to_multiple_keys(): assert_identical(actual.mean(), expected.mean()) +@pytest.mark.parametrize( + "obj, expected", + [ + (ds, set(("latitude", "longitude", "time", "X", "Y", "T"))), + (ds.air, set(("latitude", "longitude", "time", "X", "Y", "T", "area"))), + (ds_no_attrs.air, set()), + ], +) +def test_get_valid_keys(obj, expected): + actual = obj.cf.get_valid_keys() + assert actual == expected + + @pytest.mark.parametrize("obj", objects) def test_args_methods(obj): with raise_if_dask_computes(): @@ -234,3 +251,15 @@ def test_getitem_uses_coordinates(): ) assert_identical(ds.UVEL.cf["X"], ds["ULONG"].reset_coords(drop=True)) assert_identical(ds.TEMP.cf["X"], ds["TLONG"].reset_coords(drop=True)) + + +def test_plot_xincrease_yincrease(): + ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4), lon=slice(50)) + ds.lon.attrs["positive"] = "down" + ds.lat.attrs["positive"] = "down" + + f, ax = plt.subplots(1, 1) + ds.air.isel(time=1).cf.plot(ax=ax, x="X", y="Y") + + for lim in [ax.get_xlim(), ax.get_ylim()]: + assert lim[0] > lim[1]