Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement drop_sel #73

Merged
merged 12 commits into from
Mar 3, 2021
2 changes: 2 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Dataset
xarray.Dataset.pint.interp_like
xarray.Dataset.pint.reindex
xarray.Dataset.pint.reindex_like
xarray.Dataset.pint.drop_sel
xarray.Dataset.pint.sel
xarray.Dataset.pint.to

Expand All @@ -43,6 +44,7 @@ DataArray
xarray.DataArray.pint.interp_like
xarray.DataArray.pint.reindex
xarray.DataArray.pint.reindex_like
xarray.DataArray.pint.drop_sel
xarray.DataArray.pint.sel
xarray.DataArray.pint.to

Expand Down
2 changes: 2 additions & 0 deletions docs/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ What's new
By `Mika Pflüger <https://github.com/mikapfl>`_.
- implement :py:meth:`Dataset.pint.sel` and :py:meth:`DataArray.pint.sel` (:pull:`60`).
By `Justus Magin <https://github.com/keewis>`_.
- implement :py:meth:`Dataset.pint.drop_sel` and :py:meth:`DataArray.pint.drop_sel` (:pull:`73`).
By `Justus Magin <https://github.com/keewis>`_.
- implement :py:meth:`Dataset.pint.reindex`, :py:meth:`Dataset.pint.reindex_like`,
:py:meth:`DataArray.pint.reindex` and :py:meth:`DataArray.pint.reindex_like` (:pull:`69`).
By `Justus Magin <https://github.com/keewis>`_.
Expand Down
120 changes: 120 additions & 0 deletions pint_xarray/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,66 @@ def sel(
def loc(self):
...

def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs):
"""unit-aware version of drop_sel

Just like :py:meth:`xarray.DataArray.drop_sel`, except the indexers are converted
to the units of the object's indexes first.

See Also
--------
xarray.Dataset.pint.drop_sel
xarray.DataArray.drop_sel
xarray.Dataset.drop_sel
"""
indexers = either_dict_or_kwargs(labels, labels_kwargs, "drop_sel")

indexer_units = {
name: conversion.extract_indexer_units(indexer)
for name, indexer in indexers.items()
}

# make sure we only have compatible units
dims = self.da.dims
unit_attrs = conversion.extract_unit_attributes(self.da)
index_units = {
name: units for name, units in unit_attrs.items() if name in dims
}

registry = get_registry(None, index_units, indexer_units)

units = zip_mappings(indexer_units, index_units)
incompatible_units = [
key
for key, (indexer_unit, index_unit) in units.items()
if (
None not in (indexer_unit, index_unit)
and not registry.is_compatible_with(indexer_unit, index_unit)
)
]
if incompatible_units:
units1 = {key: indexer_units[key] for key in incompatible_units}
units2 = {key: index_units[key] for key in incompatible_units}
raise DimensionalityError(units1, units2)

# convert the indexers to the indexes units
converted_indexers = {
name: conversion.convert_indexer_units(indexer, index_units[name])
for name, indexer in indexers.items()
}

# index
stripped_indexers = {
name: conversion.strip_indexer_units(indexer)
for name, indexer in converted_indexers.items()
}
indexed = self.da.drop_sel(
stripped_indexers,
errors=errors,
)

return indexed


@register_dataset_accessor("pint")
class PintDatasetAccessor:
Expand Down Expand Up @@ -1325,3 +1385,63 @@ def sel(
@property
def loc(self):
...

def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs):
"""unit-aware version of drop_sel

Just like :py:meth:`xarray.Dataset.drop_sel`, except the indexers are converted
to the units of the object's indexes first.

See Also
--------
xarray.DataArray.pint.drop_sel
xarray.Dataset.drop_sel
xarray.DataArray.drop_sel
"""
indexers = either_dict_or_kwargs(labels, labels_kwargs, "drop_sel")

indexer_units = {
name: conversion.extract_indexer_units(indexer)
for name, indexer in indexers.items()
}

# make sure we only have compatible units
dims = self.ds.dims
unit_attrs = conversion.extract_unit_attributes(self.ds)
index_units = {
name: units for name, units in unit_attrs.items() if name in dims
}

registry = get_registry(None, index_units, indexer_units)

units = zip_mappings(indexer_units, index_units)
incompatible_units = [
key
for key, (indexer_unit, index_unit) in units.items()
if (
None not in (indexer_unit, index_unit)
and not registry.is_compatible_with(indexer_unit, index_unit)
)
]
if incompatible_units:
units1 = {key: indexer_units[key] for key in incompatible_units}
units2 = {key: index_units[key] for key in incompatible_units}
raise DimensionalityError(units1, units2)

# convert the indexers to the indexes units
converted_indexers = {
name: conversion.convert_indexer_units(indexer, index_units[name])
for name, indexer in indexers.items()
}

# index
stripped_indexers = {
name: conversion.strip_indexer_units(indexer)
for name, indexer in converted_indexers.items()
}
indexed = self.ds.drop_sel(
stripped_indexers,
errors=errors,
)

return indexed
25 changes: 24 additions & 1 deletion pint_xarray/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from xarray import DataArray, Dataset, IndexVariable, Variable

unit_attribute_name = "units"
slice_attributes = ("start", "stop", "step")


def array_attach_units(data, unit):
Expand Down Expand Up @@ -306,7 +307,7 @@ def strip_unit_attributes(obj, attr="units"):


def slice_extract_units(indexer):
elements = {name: getattr(indexer, name) for name in ("start", "stop", "step")}
elements = {name: getattr(indexer, name) for name in slice_attributes}
extracted_units = [
array_extract_units(value)
for name, value in elements.items()
Expand All @@ -333,6 +334,28 @@ def slice_extract_units(indexer):
return registry.Quantity(1, units_).to_base_units().units


def convert_units_slice(indexer, units):
attrs = {name: getattr(indexer, name) for name in slice_attributes}
converted = {
name: array_convert_units(value, units) if value is not None else None
for name, value in attrs.items()
}
args = [converted[name] for name in slice_attributes]

return slice(*args)


def convert_indexer_units(indexer, units):
if isinstance(indexer, slice):
return convert_units_slice(indexer, units)
elif isinstance(indexer, DataArray):
return convert_units(indexer, {None: units})
elif isinstance(indexer, Variable):
return convert_units_variable(indexer, units)
else:
return array_convert_units(indexer, units)


def extract_indexer_units(indexer):
if isinstance(indexer, slice):
return slice_extract_units(indexer)
Expand Down
143 changes: 143 additions & 0 deletions pint_xarray/tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,149 @@ def test_sel(obj, indexers, expected, error):
assert_identical(actual, expected)


@pytest.mark.parametrize(
["obj", "indexers", "expected", "error"],
(
pytest.param(
xr.Dataset(
{
"x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}),
"y": ("y", [60, 120], {"units": unit_registry.Unit("s")}),
}
),
{"x": Quantity([10, 30], "dm"), "y": Quantity([60], "s")},
xr.Dataset(
{
"x": ("x", [20], {"units": unit_registry.Unit("dm")}),
"y": ("y", [120], {"units": unit_registry.Unit("s")}),
}
),
None,
id="Dataset-identical units",
),
pytest.param(
xr.Dataset(
{
"x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}),
"y": ("y", [60, 120], {"units": unit_registry.Unit("s")}),
}
),
{"x": Quantity([1, 3], "m"), "y": Quantity([1], "min")},
xr.Dataset(
{
"x": ("x", [20], {"units": unit_registry.Unit("dm")}),
"y": ("y", [120], {"units": unit_registry.Unit("s")}),
}
),
None,
id="Dataset-compatible units",
),
pytest.param(
xr.Dataset(
{
"x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}),
"y": ("y", [60, 120], {"units": unit_registry.Unit("s")}),
}
),
{"x": Quantity([1, 3], "s"), "y": Quantity([1], "m")},
None,
DimensionalityError,
id="Dataset-incompatible units",
),
pytest.param(
xr.Dataset(
{
"x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}),
"y": ("y", [60, 120], {"units": unit_registry.Unit("s")}),
}
),
{"x": Quantity([10, 30], "m"), "y": Quantity([60], "min")},
None,
KeyError,
id="Dataset-compatible units-not found",
),
pytest.param(
xr.DataArray(
[[0, 1], [2, 3], [4, 5]],
dims=("x", "y"),
coords={
"x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}),
"y": ("y", [60, 120], {"units": unit_registry.Unit("s")}),
},
),
{"x": Quantity([10, 30], "dm"), "y": Quantity([60], "s")},
xr.DataArray(
[[3]],
dims=("x", "y"),
coords={
"x": ("x", [20], {"units": unit_registry.Unit("dm")}),
"y": ("y", [120], {"units": unit_registry.Unit("s")}),
},
),
None,
id="DataArray-identical units",
),
pytest.param(
xr.DataArray(
[[0, 1], [2, 3], [4, 5]],
dims=("x", "y"),
coords={
"x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}),
"y": ("y", [60, 120], {"units": unit_registry.Unit("s")}),
},
),
{"x": Quantity([1, 3], "m"), "y": Quantity([1], "min")},
xr.DataArray(
[[3]],
dims=("x", "y"),
coords={
"x": ("x", [20], {"units": unit_registry.Unit("dm")}),
"y": ("y", [120], {"units": unit_registry.Unit("s")}),
},
),
None,
id="DataArray-compatible units",
),
pytest.param(
xr.DataArray(
[[0, 1], [2, 3], [4, 5]],
dims=("x", "y"),
coords={
"x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}),
"y": ("y", [60, 120], {"units": unit_registry.Unit("s")}),
},
),
{"x": Quantity([10, 30], "s"), "y": Quantity([60], "m")},
None,
DimensionalityError,
id="DataArray-incompatible units",
),
pytest.param(
xr.DataArray(
[[0, 1], [2, 3], [4, 5]],
dims=("x", "y"),
coords={
"x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}),
"y": ("y", [60, 120], {"units": unit_registry.Unit("s")}),
},
),
{"x": Quantity([10, 30], "m"), "y": Quantity([60], "min")},
None,
KeyError,
id="DataArray-compatible units-not found",
),
),
)
def test_drop_sel(obj, indexers, expected, error):
if error is not None:
with pytest.raises(error):
obj.pint.drop_sel(indexers)
else:
actual = obj.pint.drop_sel(indexers)
assert_units_equal(actual, expected)
assert_identical(actual, expected)


@pytest.mark.parametrize(
["obj", "indexers", "expected", "error"],
(
Expand Down