Move metadata consolidation into partition directories (#312)
sjperkins committed Mar 20, 2024
1 parent 3faa157 commit 66c7944
Showing 3 changed files with 75 additions and 23 deletions.
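This commit moves zarr consolidated metadata from the table root into each partition subdirectory. A rough sketch of the resulting store layout, with illustrative partition names taken from the tests below (assuming the ::ANTENNA store maps to an ANTENNA subtable directory):

test.zarr/
    MAIN/
        MAIN_0/
            .zmetadata    <- consolidated metadata, now per partition
            ...
        MAIN_1/
            .zmetadata
            ...
    ANTENNA/
        ANTENNA_0/
            .zmetadata
            ...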
1 change: 1 addition & 0 deletions HISTORY.rst
@@ -5,6 +5,7 @@ History
 X.Y.Z (YYYY-MM-DD)
 ------------------
 * Only test Github Action Push events on master (:pr:`313`)
+* Move consolidated metadata into partition subdirectories (:pr:`312`)
 * Set ``_ARRAY_DIMENSIONS`` attribute on Data Variables (:pr:`311`)
 * Use JSON codec for writing zarr strings (:pr:`310`)
 * Address warnings (:pr:`309`)
45 changes: 32 additions & 13 deletions daskms/experimental/zarr/__init__.py
@@ -1,6 +1,7 @@
 from functools import reduce
 from operator import mul
 from pathlib import Path
+import os.path
 import warnings

 import dask
@@ -108,6 +109,7 @@ def create_array(ds_group, column, column_schema, schema_chunks, coordinate=False):
         column,
         column_schema.shape,
         chunks=zchunks,
+        fill_value=None,
         dtype=column_schema.dtype,
         object_codec=codec,
         exact=True,
@@ -341,13 +343,14 @@ def xds_to_zarr(xds, store, columns=None, rechunk=False, consolidated=True, **kwargs):
             **{k: getattr(ds, k) for k, _ in partition},
         }

-        write_datasets.append(Dataset(data_vars, attrs=attrs))
+        if consolidated:
+            table_name = store.table if store.table else "MAIN"
+            sep = store.fs.sep
+            store_path = f"{store.root}{sep}{table_name}{sep}{table_name}_{di}"
+            store_map = store.fs.get_mapper(store_path)
+            zc.consolidate_metadata(store_map)

-    if consolidated:
-        table_path = store.table if store.table else "MAIN"
-        store_path = f"{store.root}{store.fs.sep}{table_path}"
-        store_map = store.fs.get_mapper(store_path)
-        zc.consolidate_metadata(store_map)
+        write_datasets.append(Dataset(data_vars, attrs=attrs))

     return write_datasets
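For context, a minimal sketch of the per-partition consolidation idiom used in the write path above, written against plain zarr and fsspec rather than the dask-ms store wrapper; the root path, table name and DATA column are illustrative:

import fsspec
import zarr
import zarr.convenience as zc

root = "test.zarr"

for di in range(2):
    # A mapper scoped to each partition directory, e.g. test.zarr/MAIN/MAIN_0
    store_map = fsspec.get_mapper(f"{root}/MAIN/MAIN_{di}")
    group = zarr.open_group(store_map, mode="a")
    group.require_dataset("DATA", shape=(10,), chunks=(5,), dtype="f8")

    # Writes .zmetadata inside the partition directory itself, as
    # zc.consolidate_metadata(store_map) does in the diff above
    zc.consolidate_metadata(store_map)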

@@ -420,16 +423,32 @@ def xds_from_zarr(store, columns=None, chunks=None, consolidated=True, **kwargs):

     datasets = []
     numpy_vars = []
-    table_path = store.table if store.table else "MAIN"
+    table_name = store.table if store.table else "MAIN"

-    store_map = store.fs.get_mapper(f"{store.root}{store.fs.sep}{table_path}")
+    store_path = f"{store.root}{store.fs.sep}{table_name}"
+    store_map = store.fs.get_mapper(store_path)

-    try:
-        table_group = zarr.open_consolidated(store_map, mode="r")
-    except KeyError:
-        table_group = zarr.open_group(store_map, mode="r")
+    partition_ids = []
+
+    for entry in store_map.fs.listdir(f"{store_map.root}"):
+        if entry["type"] == "directory":
+            _, dir_name = os.path.split(entry["name"])
+            if dir_name.startswith(table_name):
+                _, i = dir_name.split("_")
+                partition_ids.append(int(i))
+
+    for g in sorted(partition_ids):
+        group_path = f"{store_path}{store.fs.sep}{table_name}_{g}"
+        group_map = store.fs.get_mapper(group_path)
+
+        if consolidated:
+            try:
+                group = zarr.open_consolidated(group_map, mode="r")
+            except KeyError:
+                group = zarr.open_group(group_map, mode="r")
+        else:
+            group = zarr.open_group(group_map, mode="r")

-    for g, (_, group) in enumerate(sorted(table_group.groups(), key=group_sortkey)):
         group_attrs = decode_attr(dict(group.attrs))
         dask_ms_attrs = group_attrs.pop(DASKMS_ATTR_KEY)
         natural_chunks = dask_ms_attrs["chunks"]
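A standalone sketch of the discovery-and-open pattern introduced in the read path above, again with an illustrative local path and table name:

import os.path

import fsspec
import zarr

root = "test.zarr"
table_name = "MAIN"
fs = fsspec.filesystem("file")

# Recover partition ids from directory names such as MAIN_0, MAIN_1, ...
partition_ids = []

for entry in fs.listdir(f"{root}/{table_name}"):
    if entry["type"] == "directory":
        _, dir_name = os.path.split(entry["name"])
        if dir_name.startswith(table_name):
            _, i = dir_name.split("_")
            partition_ids.append(int(i))

# Open each partition in order, preferring its consolidated metadata
for g in sorted(partition_ids):
    group_map = fs.get_mapper(f"{root}/{table_name}/{table_name}_{g}")

    try:
        group = zarr.open_consolidated(group_map, mode="r")
    except KeyError:
        group = zarr.open_group(group_map, mode="r")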
52 changes: 42 additions & 10 deletions daskms/experimental/zarr/tests/test_zarr.py
@@ -11,7 +11,7 @@
 from daskms import xds_from_ms, xds_from_table, xds_from_storage_ms
 from daskms.constants import DASKMS_PARTITION_KEY
 from daskms.dataset import Dataset
-from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr
+from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr, DASKMS_ATTR_KEY
 from daskms.fsspec_store import DaskMSStore, UnknownStoreTypeError

 try:
@@ -75,11 +75,11 @@ def test_xds_to_zarr_coords(tmp_path_factory):

 @pytest.mark.parametrize("consolidated", [True, False])
 def test_metadata_consolidation(ms, ant_table, tmp_path_factory, consolidated):
-    zarr_dir = tmp_path_factory.mktemp("zarr_store") / "test.zarr"
-    ant_dir = zarr_dir.parent / f"{zarr_dir.name}::ANTENNA"
+    zarr_path = tmp_path_factory.mktemp("zarr_store") / "test.zarr"
+    ant_path = zarr_path.parent / f"{zarr_path.name}::ANTENNA"

-    main_store = DaskMSStore(zarr_dir)
-    ant_store = DaskMSStore(ant_dir)
+    main_store = DaskMSStore(zarr_path)
+    ant_store = DaskMSStore(ant_path)

     ms_datasets = xds_from_ms(ms)
     ant_datasets = xds_from_table(ant_table)
@@ -95,14 +95,14 @@ def test_metadata_consolidation(ms, ant_table, tmp_path_factory, consolidated):
     writes.extend(xds_to_zarr(ant_datasets, ant_store, consolidated=consolidated))
     dask.compute(writes)

-    assert main_store.exists("MAIN/.zmetadata") is consolidated
-    assert ant_store.exists(".zmetadata") is consolidated
+    assert main_store.exists("MAIN/MAIN_0/.zmetadata") is consolidated
+    assert ant_store.exists("ANTENNA_0/.zmetadata") is consolidated

     if consolidated:
-        with main_store.open("MAIN/.zmetadata") as f:
+        with main_store.open("MAIN/MAIN_0/.zmetadata") as f:
             assert "test-meta".encode("utf8") in f.read()

-        with ant_store.open(".zmetadata") as f:
+        with ant_store.open("ANTENNA_0/.zmetadata") as f:
             assert "test-meta".encode("utf8") in f.read()

     for ds in xds_from_zarr(main_store, consolidated=consolidated):
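The placement the test asserts can also be checked directly with fsspec; a sketch under the same illustrative paths, assuming a store freshly written with consolidated=True:

import fsspec

fs = fsspec.filesystem("file")

# .zmetadata now lives inside each partition directory ...
assert fs.exists("test.zarr/MAIN/MAIN_0/.zmetadata")

# ... and is no longer written at the table root
assert not fs.exists("test.zarr/MAIN/.zmetadata")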
@@ -429,6 +429,38 @@ def test_xarray_reading_daskms_written_dataset(ms, tmp_path_factory):
     path = store / "test.zarr"
     dask.compute(xds_to_zarr(datasets, path, consolidated=True))

+    extra = tmp_path_factory.mktemp("extra")
+    datasets[0].to_zarr(extra)
+
+    for i, mem_ds in enumerate(datasets):
+        ds = xarray.open_zarr(path / "MAIN" / f"MAIN_{i}")
+        assert ds == mem_ds
+
+        for k in (set(ds.data_vars) | set(mem_ds.data_vars)) - {"ROWID"}:
+            assert_array_equal(ds.data_vars[k], mem_ds.data_vars[k])
+
+        for k in (set(ds.coords) | set(mem_ds.coords)) - {"ROWID"}:
+            assert_array_equal(ds.coords[k], mem_ds.coords[k])
+
+        # capitalised ROWID breaks the lowercase
+        # xarray coordinate naming convention,
+        # so xarray treats it as a data variable
+        assert_array_equal(ds.data_vars["ROWID"], mem_ds.coords["ROWID"])
+
+        attr_keys = (set(mem_ds.attrs) | set(ds.attrs)) - {
+            DASKMS_PARTITION_KEY,
+            DASKMS_ATTR_KEY,
+        }
+
+        for k in attr_keys:
+            assert ds.attrs[k] == mem_ds.attrs[k]
+
+        # DASKMS_ATTR_KEY is added by dask-ms when writing,
+        # so xarray reads it in
+        assert DASKMS_ATTR_KEY in ds.attrs
+        assert DASKMS_ATTR_KEY not in mem_ds.attrs
+
+        # xarray converts tuples to lists when JSON encoding
+        assert (
+            tuple(map(tuple, ds.attrs[DASKMS_PARTITION_KEY]))
+            == mem_ds.attrs[DASKMS_PARTITION_KEY]
+        )
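Since each partition group now carries its own consolidated metadata, it is a self-contained zarr store that xarray can open directly; a sketch with an illustrative path:

import xarray

# xarray picks up the per-partition .zmetadata automatically
ds = xarray.open_zarr("test.zarr/MAIN/MAIN_0")
print(ds)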
