Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make sure re-quantifying dimension coordinates works #174

Merged
merged 6 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ What's new
By `Justus Magin <https://github.com/keewis>`_.
- preserve :py:class:`pandas.MultiIndex` objects (:issue:`164`, :pull:`168`).
By `Justus Magin <https://github.com/keewis>`_.
- fix "quantifying" dimension coordinates (:issue:`105`, :pull:`174`).
By `Justus Magin <https://github.com/keewis>`_.

0.2.1 (26 Jul 2021)
-------------------
Expand Down
55 changes: 54 additions & 1 deletion pint_xarray/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ def _decide_units(units, registry, unit_attribute):
elif units is _default:
if unit_attribute in no_unit_values:
return unit_attribute
units = registry.parse_units(unit_attribute)
if isinstance(unit_attribute, Unit):
units = unit_attribute
else:
units = registry.parse_units(unit_attribute)
else:
units = registry.parse_units(units)
return units
Expand Down Expand Up @@ -360,6 +363,31 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
if invalid_units:
raise ValueError(format_error_message(invalid_units, "parse"))

existing_units = {
name: unit
for name, unit in conversion.extract_units(self.da).items()
if isinstance(unit, Unit)
}
overwritten_units = {
name: (old, new)
for name, (old, new) in zip_mappings(
existing_units, new_units, fill_value=_default
).items()
if old is not _default and new is not _default
}
if overwritten_units:
errors = {
name: (
new,
ValueError(
f"Cannot attach unit {repr(new)} to quantity: data "
f"already has units {repr(old)}"
),
)
for name, (old, new) in overwritten_units.items()
}
raise ValueError(format_error_message(errors, "attach"))

return self.da.pipe(conversion.strip_unit_attributes).pipe(
conversion.attach_units, new_units
)
Expand Down Expand Up @@ -1050,6 +1078,31 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
if invalid_units:
raise ValueError(format_error_message(invalid_units, "parse"))

existing_units = {
name: unit
for name, unit in conversion.extract_units(self.ds).items()
if isinstance(unit, Unit)
}
overwritten_units = {
name: (old, new)
for name, (old, new) in zip_mappings(
existing_units, new_units, fill_value=_default
).items()
if old is not _default and new is not _default
}
if overwritten_units:
errors = {
name: (
new,
ValueError(
f"Cannot attach unit {repr(new)} to quantity: data "
f"already has units {repr(old)}"
),
)
for name, (old, new) in overwritten_units.items()
}
raise ValueError(format_error_message(errors, "attach"))

return self.ds.pipe(conversion.strip_unit_attributes).pipe(
conversion.attach_units, new_units
)
Expand Down
30 changes: 30 additions & 0 deletions pint_xarray/tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,22 @@ def test_parse_integer_inverse(self):
result = da.pint.quantify()
assert result.pint.units == Unit("1 / meter")

def test_dimension_coordinate(self):
ds = xr.Dataset(coords={"x": ("x", [10], {"units": "m"})})
arr = ds.x

# does not actually quantify because `arr` wraps a IndexVariable
# but we still get a `Unit` in the attrs
q = arr.pint.quantify()
assert isinstance(q.attrs["units"], Unit)

def test_dimension_coordinate_already_quantified(self):
ds = xr.Dataset(coords={"x": ("x", [10], {"units": unit_registry.Unit("m")})})
arr = ds.x

with pytest.raises(ValueError):
arr.pint.quantify({"x": "s"})


@pytest.mark.parametrize("formatter", ("", "P", "C"))
@pytest.mark.parametrize("modifier", ("", "~"))
Expand Down Expand Up @@ -313,6 +329,20 @@ def test_error_indicates_problematic_variable(self, example_unitless_ds):
with pytest.raises(ValueError, match="'users'"):
ds.pint.quantify(units={"users": "aecjhbav"})

def test_existing_units(self, example_quantity_ds):
ds = example_quantity_ds.copy()
ds.t.attrs["units"] = unit_registry.Unit("m")

with pytest.raises(ValueError, match="Cannot attach"):
ds.pint.quantify({"funds": "kg"})

def test_existing_units_dimension(self, example_quantity_ds):
ds = example_quantity_ds.copy()
ds.t.attrs["units"] = unit_registry.Unit("m")

with pytest.raises(ValueError, match="Cannot attach"):
ds.pint.quantify({"t": "s"})


class TestDequantifyDataSet:
def test_strip_units(self, example_quantity_ds):
Expand Down