Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ repos:
files: .+\.py$
# https://github.com/python/black#version-control-integration
- repo: https://github.com/python/black
rev: stable
rev: 19.10b0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
rev: 3.7.9
rev: 3.8.3
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.761 # Must match ci/requirements/*.yml
rev: v0.781 # Must match ci/requirements/*.yml
hooks:
- id: mypy
# run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194
Expand Down
161 changes: 138 additions & 23 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import functools
import inspect
import itertools
import textwrap
from collections import ChainMap
from contextlib import suppress
from typing import Callable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
from typing import (
Callable,
Hashable,
List,
Mapping,
MutableMapping,
Optional,
Set,
Tuple,
Union,
)

import xarray as xr
from xarray import DataArray, Dataset
Expand Down Expand Up @@ -106,6 +117,11 @@
]


def _strip_none_list(lst: List[Optional[str]]) -> List[str]:
""" The mappers can return [None]. Strip that when necessary. Keeps mypy happy."""
return [item for item in lst if item != [None]] # type: ignore


def _get_axis_coord_single(
var: Union[xr.DataArray, xr.Dataset],
key: str,
Expand Down Expand Up @@ -176,8 +192,8 @@ def _get_axis_coord(
results: Set = set()
for coord in search_in:
for criterion, valid_values in coordinate_criteria.items():
if key in valid_values: # type: ignore
expected = valid_values[key] # type: ignore
if key in valid_values:
expected = valid_values[key]
if var.coords[coord].attrs.get(criterion, None) in expected:
results.update((coord,))

Expand Down Expand Up @@ -246,6 +262,31 @@ def _get_measure(
}


def _filter_by_standard_names(ds: xr.Dataset, name: Union[str, List[str]]) -> List[str]:
""" returns a list of variable names with standard names matching name. """
if isinstance(name, str):
name = [name]

varnames = []
counts = dict.fromkeys(name, 0)
for vname, var in ds.variables.items():
stdname = var.attrs.get("standard_name", None)
if stdname in name:
varnames.append(str(vname))
counts[stdname] += 1

return varnames


def _get_list_standard_names(obj: xr.Dataset) -> List[str]:
""" Returns a sorted list of standard names in Dataset. """
names = []
for k, v in obj.variables.items():
if "standard_name" in v.attrs:
names.append(v.attrs["standard_name"])
return sorted(names)


def _getattr(
obj: Union[DataArray, Dataset],
attr: str,
Expand Down Expand Up @@ -503,6 +544,16 @@ def _describe(self):
text += f"\t{measure}: unsupported\n"
else:
text += f"\t{measure}: {_get_measure(self._obj, measure, error=False, default=None)}\n"

text += "\nStandard Names:\n"
if isinstance(self._obj, xr.DataArray):
text += "\tunsupported\n"
else:
stdnames = _get_list_standard_names(self._obj)
text += "\t"
text += "\n".join(
textwrap.wrap(f"{stdnames!r}", 70, break_long_words=False)
)
return text

def describe(self):
Expand All @@ -529,32 +580,96 @@ def get_valid_keys(self) -> Set[str]:
]
if measures:
varnames.append(*measures)

if not isinstance(self._obj, xr.DataArray):
varnames.extend(_get_list_standard_names(self._obj))
return set(varnames)

def __getitem__(self, key: Union[str, List[str]]):
    """Select variables by CF name(s), together with their auxiliary variables.

    Each key is resolved in turn:

    * a name in ``_AXIS_NAMES`` or ``_COORD_NAMES`` goes through
      ``_get_axis_coord``;
    * a name in ``_CELL_MEASURES`` goes through ``_get_measure`` (DataArray
      only);
    * when the wrapped object is not a DataArray, any other name is matched
      against ``standard_name`` attributes via ``_filter_by_standard_names``.

    Keys that resolve to nothing are then tried as ordinary variable names of
    the wrapped object. Names listed in the ``coordinates`` and
    ``cell_measures`` attributes of the selected variables are added to the
    selection and marked as coordinates.

    Parameters
    ----------
    key : str or list of str
        A single name, or several names.

    Returns
    -------
    For a Dataset, or for a DataArray indexed with a list of keys, the
    selection with the collected names set as coordinates. For a DataArray
    indexed with a single string that resolves to exactly one variable, that
    variable with length-one dimensions squeezed out.

    Raises
    ------
    NotImplementedError
        If a cell-measure key is used on a Dataset, or if a single string key
        on a DataArray leaves more than one coordinate in the selection.
    KeyError
        If any lookup on the wrapped object raises ``KeyError``.
    """
    kind = str(type(self._obj).__name__)
    scalar_key = isinstance(key, str)
    if scalar_key:
        key = (key,)  # type: ignore

    varnames: List[Hashable] = []
    coords: List[Hashable] = []
    # Tracks, per key, whether one of the CF lookups below produced anything.
    successful = dict.fromkeys(key, False)
    for k in key:
        if k in _AXIS_NAMES + _COORD_NAMES:
            names = _get_axis_coord(self._obj, k)
            successful[k] = bool(names)
            varnames.extend(_strip_none_list(names))
            coords.extend(_strip_none_list(names))
        elif k in _CELL_MEASURES:
            if isinstance(self._obj, xr.Dataset):
                # The f prefix is required so the offending key is interpolated.
                raise NotImplementedError(
                    f"Invalid key {k!r}. Cell measures not implemented for Dataset yet."
                )
            else:
                measure = _get_measure(self._obj, k)
                successful[k] = bool(measure)
                if measure:
                    varnames.append(measure)
        elif not isinstance(self._obj, xr.DataArray):
            stdnames = _filter_by_standard_names(self._obj, k)
            successful[k] = bool(stdnames)
            varnames.extend(stdnames)
            coords.extend(list(set(stdnames).intersection(set(self._obj.coords))))

    # these are not special names but could be variable names in underlying object
    # we allow this so that we can return variables with appropriate CF auxiliary variables
    varnames.extend([k for k, v in successful.items() if not v])
    assert len(varnames) > 0

    try:
        # TODO: make this a get_auxiliary_variables function
        # make sure to set coordinate variables referred to in "coordinates" attribute
        for name in varnames:
            attrs = self._obj[name].attrs
            if "coordinates" in attrs:
                # Whitespace split: runs of blanks must not yield empty names.
                coords.extend(attrs["coordinates"].split())

            if "cell_measures" in attrs:
                measures = [
                    _get_measure(self._obj[name], measure)
                    for measure in _CELL_MEASURES
                    if measure in attrs["cell_measures"]
                ]
                coords.extend(_strip_none_list(measures))

        varnames.extend(coords)
        if isinstance(self._obj, xr.DataArray):
            ds = self._obj._to_temp_dataset()
        else:
            ds = self._obj
        ds = ds.reset_coords()[varnames]
        if isinstance(self._obj, DataArray):
            if scalar_key and len(ds.variables) == 1:
                # single dimension coordinates
                return ds[list(ds.variables.keys())[0]].squeeze(drop=True)
            elif scalar_key and len(ds.coords) > 1:
                raise NotImplementedError(
                    "Not sure what to return when given scalar key for DataArray and it has multiple values. "
                    "Please open an issue."
                )
            elif not scalar_key:
                return ds.set_coords(coords)
            # NOTE(review): a scalar key that matches neither branch above
            # falls through and the method returns None.
        else:
            return ds.set_coords(coords)

    except KeyError:
        # ``k`` is the last key visited by the resolution loop above.
        raise KeyError(
            f"{kind}.cf does not understand the key {k!r}. "
            f"Use {kind}.cf.describe() to see a list of key names that can be interpreted."
        )


# Accessor class registered under the name "cf" for xarray Dataset objects.
# NOTE(review): this listing looks like a diff rendered without +/- markers.
# The __getitem__ below appears to be the removed override and the trailing
# `pass` the new class body, in which case item lookup would come from
# CFAccessor. Confirm against the actual file.
@xr.register_dataset_accessor("cf")
class CFDatasetAccessor(CFAccessor):
def __getitem__(self, key):
# Axis / coordinate names: select the matching variables and flag them as coordinates.
if key in _AXIS_NAMES + _COORD_NAMES:
varnames = _get_axis_coord(self._obj, key)
return self._obj.reset_coords()[varnames].set_coords(varnames)
# Cell measures are rejected for Datasets.
elif key in _CELL_MEASURES:
raise NotImplementedError("measures not implemented for Dataset yet.")
# Any other key is reported as not understood.
else:
raise KeyError(
f"Dataset.cf does not understand the key {key!r}. Use Dataset.cf.describe() to see a list of key names that can be interpreted."
)
pass


# Accessor class registered under the name "cf" for xarray DataArray objects.
# NOTE(review): this listing looks like a diff rendered without +/- markers.
# The __getitem__ below appears to be the removed override and the trailing
# `pass` the new class body, in which case item lookup would come from
# CFAccessor. Confirm against the actual file.
@xr.register_dataarray_accessor("cf")
class CFDataArrayAccessor(CFAccessor):
def __getitem__(self, key):
# Axis / coordinate names: resolve to one variable name and return that variable.
if key in _AXIS_NAMES + _COORD_NAMES:
varname = _get_axis_coord_single(self._obj, key)
return self._obj[varname].reset_coords(drop=True)
# Cell measures: index the object with the name returned by _get_measure.
elif key in _CELL_MEASURES:
return self._obj[_get_measure(self._obj, key)]
# Any other key is reported as not understood.
else:
raise KeyError(
f"DataArray.cf does not understand the key {key!r}. Use DataArray.cf.describe() to see a list of key names that can be interpreted."
)
pass
49 changes: 49 additions & 0 deletions cf_xarray/tests/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
import xarray as xr

# First four time steps and fifty longitudes of the xarray tutorial
# "air_temperature" dataset, tagged with CF attributes and a cell-area coordinate.
airds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4), lon=slice(50))
airds.air.attrs.update(
    cell_measures="area: cell_area",
    standard_name="air_temperature",
)
airds.coords["cell_area"] = (
    xr.DataArray(np.cos(airds.lat * np.pi / 180))
    * xr.ones_like(airds.lon)
    * 105e3
    * 110e3
)

# Deep copy of ``airds`` in which every variable has its attributes emptied.
ds_no_attrs = airds.copy(deep=True)
for variable in ds_no_attrs.variables:
    ds_no_attrs[variable].attrs = {}


def _build_popds():
    """Assemble the synthetic dataset with 2D longitude/latitude coordinates.

    It holds four (nlat, nlon) coordinate variables carrying ``axis`` and
    ``units`` attributes, and two data variables that name their coordinates
    through the ``coordinates`` attribute.
    """
    dims = ("nlat", "nlon")
    shape = (20, 30)

    ds = xr.Dataset()

    # (coordinate name, constant fill value, axis, units), in insertion order
    coord_specs = (
        ("TLONG", 1, "X", "degrees_east"),
        ("TLAT", 2, "Y", "degrees_north"),
        ("ULONG", 0.5, "X", "degrees_east"),
        ("ULAT", 2.5, "Y", "degrees_north"),
    )
    for coord_name, fill, axis, units in coord_specs:
        ds.coords[coord_name] = (
            dims,
            fill * np.ones(shape),
            {"axis": axis, "units": units},
        )

    # (variable name, value of "coordinates", standard name), in insertion order
    data_specs = (
        ("UVEL", "ULONG ULAT", "sea_water_x_velocity"),
        ("TEMP", "TLONG TLAT", "sea_water_potential_temperature"),
    )
    for var_name, aux_coords, standard_name in data_specs:
        ds[var_name] = (
            dims,
            np.ones(shape) * 15,
            {"coordinates": aux_coords, "standard_name": standard_name},
        )

    return ds


popds = _build_popds()
Loading