Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import itertools
from collections import ChainMap
from contextlib import suppress
from typing import Callable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union

import xarray as xr
Expand Down Expand Up @@ -226,6 +227,13 @@ def _get_measure(
else:
return default
measures = dict(zip(strings[slice(0, None, 2)], strings[slice(1, None, 2)]))
if key not in measures:
if error:
raise KeyError(
f"Cell measure {key!r} not found. Please use .cf.describe() to see a list of key names that can be interpreted."
)
else:
return default
return measures[key]


Expand All @@ -244,6 +252,7 @@ def _getattr(
accessor: "CFAccessor",
key_mappers: Mapping[str, Mapper],
wrap_classes: bool = False,
extra_decorator: Callable = None,
):
"""
Common getattr functionality.
Expand All @@ -261,13 +270,17 @@ def _getattr(
Only True for the high level CFAccessor.
Facilitates code reuse for _CFWrappedClass and _CFWrapppedPlotMethods
For both of those, wrap_classes is False.
extra_decorator: Callable (optional)
An extra decorator, if necessary. This is used by _CFPlotMethods to set default
kwargs based on CF attributes.
"""
func = getattr(obj, attr)
func: Callable = getattr(obj, attr)

@functools.wraps(func)
def wrapper(*args, **kwargs):
arguments = accessor._process_signature(func, args, kwargs, key_mappers)
result = func(**arguments)
final_func = extra_decorator(func) if extra_decorator else func
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only design I could think of but it seems a little complicated.

result = final_func(**arguments)
if wrap_classes and isinstance(result, _WRAPPED_CLASSES):
result = _CFWrappedClass(result, accessor)

Expand Down Expand Up @@ -312,21 +325,58 @@ def __init__(self, obj, accessor):
self.accessor = accessor
self._keys = ("x", "y", "hue", "col", "row")

def _plot_decorator(self, func):
"""
This decorator is used to set kwargs on plotting functions.
"""
valid_keys = self.accessor.get_valid_keys()

@functools.wraps(func)
def _plot_wrapper(*args, **kwargs):
if "x" in kwargs:
if kwargs["x"] in valid_keys:
xvar = self.accessor[kwargs["x"]]
else:
xvar = self._obj[kwargs["x"]]
if "positive" in xvar.attrs:
if xvar.attrs["positive"] == "down":
kwargs.setdefault("xincrease", False)
else:
kwargs.setdefault("xincrease", True)

if "y" in kwargs:
if kwargs["y"] in valid_keys:
yvar = self.accessor[kwargs["y"]]
else:
yvar = self._obj[kwargs["y"]]
if "positive" in yvar.attrs:
if yvar.attrs["positive"] == "down":
kwargs.setdefault("yincrease", False)
else:
kwargs.setdefault("yincrease", True)

return func(*args, **kwargs)

return _plot_wrapper

def __call__(self, *args, **kwargs):
plot = _getattr(
obj=self._obj,
attr="plot",
accessor=self.accessor,
key_mappers=dict.fromkeys(self._keys, _get_axis_coord_single),
)
return plot(*args, **kwargs)
return self._plot_decorator(plot)(*args, **kwargs)

def __getattr__(self, attr):
return _getattr(
obj=self._obj.plot,
attr=attr,
accessor=self.accessor,
key_mappers=dict.fromkeys(self._keys, _get_axis_coord_single),
# TODO: "extra_decorator" is more complex than I would like it to be.
# Not sure if there is a better way though
extra_decorator=self._plot_decorator,
)


Expand Down Expand Up @@ -458,6 +508,29 @@ def _describe(self):
def describe(self):
print(self._describe())

def get_valid_keys(self) -> Set[str]:
"""
Returns valid keys for .cf[]

Returns
-------
Set of valid key names that can be used with __getitem__ or .cf[key].
"""
varnames = [
key
for key in _AXIS_NAMES + _COORD_NAMES
if _get_axis_coord(self._obj, key, error=False, default=None) != [None]
]
with suppress(NotImplementedError):
measures = [
key
for key in _CELL_MEASURES
if _get_measure(self._obj, key, error=False) is not None
]
if measures:
varnames.append(*measures)
return set(varnames)


@xr.register_dataset_accessor("cf")
class CFDatasetAccessor(CFAccessor):
Expand Down
29 changes: 29 additions & 0 deletions cf_xarray/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
ds.coords["cell_area"] = (
xr.DataArray(np.cos(ds.lat * np.pi / 180)) * xr.ones_like(ds.lon) * 105e3 * 110e3
)
ds_no_attrs = ds.copy(deep=True)
for variable in ds_no_attrs.variables:
ds_no_attrs[variable].attrs = {}

datasets = [ds, ds.chunk({"lat": 5})]
dataarrays = [ds.air, ds.air.chunk({"lat": 5})]
objects = datasets + dataarrays
Expand Down Expand Up @@ -121,6 +125,19 @@ def test_kwargs_expand_key_to_multiple_keys():
assert_identical(actual.mean(), expected.mean())


@pytest.mark.parametrize(
"obj, expected",
[
(ds, set(("latitude", "longitude", "time", "X", "Y", "T"))),
(ds.air, set(("latitude", "longitude", "time", "X", "Y", "T", "area"))),
(ds_no_attrs.air, set()),
],
)
def test_get_valid_keys(obj, expected):
actual = obj.cf.get_valid_keys()
assert actual == expected


@pytest.mark.parametrize("obj", objects)
def test_args_methods(obj):
with raise_if_dask_computes():
Expand Down Expand Up @@ -234,3 +251,15 @@ def test_getitem_uses_coordinates():
)
assert_identical(ds.UVEL.cf["X"], ds["ULONG"].reset_coords(drop=True))
assert_identical(ds.TEMP.cf["X"], ds["TLONG"].reset_coords(drop=True))


def test_plot_xincrease_yincrease():
ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4), lon=slice(50))
ds.lon.attrs["positive"] = "down"
ds.lat.attrs["positive"] = "down"

f, ax = plt.subplots(1, 1)
ds.air.isel(time=1).cf.plot(ax=ax, x="X", y="Y")

for lim in [ax.get_xlim(), ax.get_ylim()]:
assert lim[0] > lim[1]