Skip to content

Commit

Permalink
Allow target_chunks dict syntax for xarray inputs (#72)
Browse files Browse the repository at this point in the history
* unstaged files

* try to fix black

* try with pre-commit

* Update rechunker/api.py

Co-authored-by: Eric Czech <eric.allen.czech@gmail.com>

Co-authored-by: Eric Czech <eric.allen.czech@gmail.com>
  • Loading branch information
rabernat and eric-czech committed Dec 10, 2020
1 parent 91420d5 commit 15f7e31
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
8 changes: 5 additions & 3 deletions rechunker/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,12 @@ def rechunk(
executor: Union[str, Executor] = "dask",
) -> Rechunked:
"""
Rechunk a Zarr Array or Group, or a Dask Array
Rechunk a Zarr Array or Group, a Dask Array, or an Xarray Dataset
Parameters
----------
source : zarr.Array, zarr.Group, or dask.array.Array
Named dimensions in the Arrays will be parsed according to the
source : zarr.Array, zarr.Group, dask.array.Array, or xarray.Dataset
Named dimensions in the Zarr arrays will be parsed according to the
Xarray :ref:`xarray:zarr_encoding`.
target_chunks : tuple, dict, or None
The desired chunks of the array after rechunking. The structure
Expand Down Expand Up @@ -361,6 +361,8 @@ def _setup_rechunk(
variable, raise_on_invalid=False, name=name
)
variable_chunks = target_chunks.get(name, variable_encoding["chunks"])
if isinstance(variable_chunks, dict):
variable_chunks = _shape_dict_to_tuple(variable.dims, variable_chunks)

# Restrict options to only those that are specific to zarr and
# not managed internally
Expand Down
16 changes: 12 additions & 4 deletions tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def test_invalid_executor():

@pytest.mark.parametrize("shape", [(100, 50)])
@pytest.mark.parametrize("source_chunks", [(10, 50)])
@pytest.mark.parametrize("target_chunks", [(20, 10)])
@pytest.mark.parametrize(
"target_chunks",
[{"a": (20, 10), "b": (20,)}, {"a": {"x": 20, "y": 10}, "b": {"x": 20}}],
)
@pytest.mark.parametrize("max_mem", ["10MB"])
@pytest.mark.parametrize("executor", ["dask"])
@pytest.mark.parametrize("target_store", ["target.zarr", "mapper.target.zarr"])
Expand Down Expand Up @@ -94,7 +97,7 @@ def test_rechunk_dataset(
)
rechunked = api.rechunk(
ds,
target_chunks=dict(a=target_chunks, b=target_chunks[:1]),
target_chunks=target_chunks,
max_mem=max_mem,
target_store=target_store,
target_options=options,
Expand All @@ -112,8 +115,13 @@ def test_rechunk_dataset(

# Validate decoded variables
dst = xarray.open_zarr(target_store, decode_cf=True)
assert dst.a.data.chunksize == target_chunks
assert dst.b.data.chunksize == target_chunks[:1]
target_chunks_expected = (
target_chunks["a"]
if isinstance(target_chunks["a"], tuple)
else (target_chunks["a"]["x"], target_chunks["a"]["y"])
)
assert dst.a.data.chunksize == target_chunks_expected
assert dst.b.data.chunksize == target_chunks_expected[:1]
assert dst.c.data.chunksize == source_chunks[1:]
xarray.testing.assert_equal(ds.compute(), dst.compute())
assert ds.attrs == dst.attrs
Expand Down

0 comments on commit 15f7e31

Please sign in to comment.