Skip to content

Commit

Permalink
Use zarr to validate attrs when writing to zarr (#6636)
Browse files Browse the repository at this point in the history
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
  • Loading branch information
malmans2 and dcherian committed Jun 3, 2022
1 parent 7bdb0e4 commit b080349
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 4 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ Deprecations
Bug fixes
~~~~~~~~~

- :py:meth:`Dataset.to_zarr` now allows writing all attribute types supported by `zarr-python`.
By `Mattia Almansi <https://github.com/malmans2>`_.
- Set ``skipna=None`` for all ``quantile`` methods (e.g. :py:meth:`Dataset.quantile`) and
ensure it skips missing values for float dtypes (consistent with other methods). This should
not change the behavior (:pull:`6303`).
Expand Down
3 changes: 1 addition & 2 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1561,9 +1561,8 @@ def to_zarr(
f"'w-', 'a' and 'r+', but mode={mode!r}"
)

# validate Dataset keys, DataArray names, and attr keys/values
# validate Dataset keys, DataArray names
_validate_dataset_names(dataset)
_validate_attrs(dataset)

if region is not None:
_validate_region(dataset, region)
Expand Down
13 changes: 11 additions & 2 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,15 @@ def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim)
)


def _put_attrs(zarr_obj, attrs):
"""Raise a more informative error message for invalid attrs."""
try:
zarr_obj.attrs.put(attrs)
except TypeError as e:
raise TypeError("Invalid attribute in Dataset.attrs.") from e
return zarr_obj


class ZarrStore(AbstractWritableDataStore):
"""Store for reading and writing data via zarr"""

Expand Down Expand Up @@ -479,7 +488,7 @@ def set_dimensions(self, variables, unlimited_dims=None):
)

def set_attributes(self, attributes):
self.zarr_group.attrs.put(attributes)
_put_attrs(self.zarr_group, attributes)

def encode_variable(self, variable):
variable = encode_zarr_variable(variable)
Expand Down Expand Up @@ -618,7 +627,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
zarr_array = self.zarr_group.create(
name, shape=shape, dtype=dtype, fill_value=fill_value, **encoding
)
zarr_array.attrs.put(encoded_attrs)
zarr_array = _put_attrs(zarr_array, encoded_attrs)

write_region = self._write_region if self._write_region is not None else {}
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}
Expand Down
16 changes: 16 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2443,6 +2443,22 @@ def test_write_read_select_write(self):
with self.create_zarr_target() as final_store:
ds_sel.to_zarr(final_store, mode="w")

@pytest.mark.parametrize("obj", [Dataset(), DataArray(name="foo")])
def test_attributes(self, obj):
    """Check attribute handling when writing to zarr.

    A nested-dict attribute must round-trip through ``to_zarr`` /
    ``open_zarr`` unchanged, while an attribute holding a ``DataArray``
    must make ``to_zarr`` raise ``TypeError`` with the
    "Invalid attribute in Dataset.attrs." message.
    """
    # Work on a copy; the attrs below are added to the copy.
    obj = obj.copy()

    # A dict-valued attribute is expected to be written and read back intact.
    obj.attrs["good"] = {"key": "value"}
    ds = obj if isinstance(obj, Dataset) else obj.to_dataset()
    with self.create_zarr_target() as store_target:
        ds.to_zarr(store_target)
        assert_identical(ds, xr.open_zarr(store_target))

    # A DataArray-valued attribute is expected to be rejected on write.
    obj.attrs["bad"] = DataArray()
    ds = obj if isinstance(obj, Dataset) else obj.to_dataset()
    with self.create_zarr_target() as store_target:
        # ``match`` is a regular expression, so the dots are escaped to
        # require the literal message rather than any character.
        with pytest.raises(
            TypeError, match=r"Invalid attribute in Dataset\.attrs\."
        ):
            ds.to_zarr(store_target)


@requires_zarr
class TestZarrDictStore(ZarrBase):
Expand Down

0 comments on commit b080349

Please sign in to comment.