diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index b922e7f3949..54a273fbdc3 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -96,6 +96,8 @@ Deprecations
 Bug fixes
 ~~~~~~~~~
 
+- :py:meth:`Dataset.to_zarr` now allows writing all attribute types supported by ``zarr-python``.
+  By `Mattia Almansi <https://github.com/malmans2>`_.
 - Set ``skipna=None`` for all ``quantile`` methods (e.g. :py:meth:`Dataset.quantile`) and
   ensure it skips missing values for float dtypes (consistent with other methods). This should
   not change the behavior (:pull:`6303`).
diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 1426cd320d9..cc1e143e573 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -1561,9 +1561,8 @@ def to_zarr(
             f"'w-', 'a' and 'r+', but mode={mode!r}"
         )
 
-    # validate Dataset keys, DataArray names, and attr keys/values
+    # validate Dataset keys, DataArray names
     _validate_dataset_names(dataset)
-    _validate_attrs(dataset)
 
     if region is not None:
         _validate_region(dataset, region)
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 104f8aca58f..1b8b7ee81e7 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -320,6 +320,19 @@ def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim)
         )
 
 
+def _put_attrs(zarr_obj, attrs):
+    """Write ``attrs`` via ``zarr_obj.attrs.put`` and return ``zarr_obj``.
+
+    A ``TypeError`` raised by the write is re-raised, chained to the
+    original exception, with a more informative error message.
+    """
+    try:
+        zarr_obj.attrs.put(attrs)
+    except TypeError as e:
+        raise TypeError("Invalid attribute in Dataset.attrs.") from e
+    return zarr_obj
+
+
 class ZarrStore(AbstractWritableDataStore):
     """Store for reading and writing data via zarr"""
 
@@ -479,7 +492,7 @@ def set_dimensions(self, variables, unlimited_dims=None):
             )
 
     def set_attributes(self, attributes):
-        self.zarr_group.attrs.put(attributes)
+        _put_attrs(self.zarr_group, attributes)
 
     def encode_variable(self, variable):
         variable = encode_zarr_variable(variable)
@@ -618,7 +631,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
                 zarr_array = self.zarr_group.create(
                     name, shape=shape, dtype=dtype, fill_value=fill_value, **encoding
                 )
-                zarr_array.attrs.put(encoded_attrs)
+                zarr_array = _put_attrs(zarr_array, encoded_attrs)
 
             write_region = self._write_region if self._write_region is not None else {}
             write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 62c7c1aac31..6f92b26b0c9 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -2443,6 +2443,22 @@ def test_write_read_select_write(self):
         with self.create_zarr_target() as final_store:
             ds_sel.to_zarr(final_store, mode="w")
 
+    @pytest.mark.parametrize("obj", [Dataset(), DataArray(name="foo")])
+    def test_attributes(self, obj):
+        obj = obj.copy()
+
+        obj.attrs["good"] = {"key": "value"}
+        ds = obj if isinstance(obj, Dataset) else obj.to_dataset()
+        with self.create_zarr_target() as store_target:
+            ds.to_zarr(store_target)
+            assert_identical(ds, xr.open_zarr(store_target))
+
+        obj.attrs["bad"] = DataArray()
+        ds = obj if isinstance(obj, Dataset) else obj.to_dataset()
+        with self.create_zarr_target() as store_target:
+            with pytest.raises(TypeError, match=r"Invalid attribute in Dataset.attrs."):
+                ds.to_zarr(store_target)
+
 
 @requires_zarr
 class TestZarrDictStore(ZarrBase):