
Add an epoch argument to xds_{from,to}_zarr to uniquely identify datasets in a distributed context (#330)
sjperkins committed May 21, 2024
1 parent 747408f commit abdf497
Showing 3 changed files with 28 additions and 7 deletions.
2 changes: 2 additions & 0 deletions HISTORY.rst
@@ -4,6 +4,8 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Add an epoch argument to xds_{from,to}_zarr to uniquely identify
datasets in a distributed context (:pr:`330`)
* Improve table schema handling (:pr:`329`)
* Identify channel and correlation-like dimensions in non-standard MS columns (:pr:`329`)
* DaskMSStore depends on ``fsspec >= 2022.7.0`` (:pr:`328`)
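The new keyword can be exercised as follows (an illustrative sketch, not part of this commit; the store paths and the epoch value are hypothetical). Leaving epoch as None generates a fresh 16-character hex token per call, while pinning it explicitly lets graphs built in different processes agree on their keys:

from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr

# Default: a fresh epoch is generated internally, so repeated reads of the
# same store produce dask graphs with distinct keys.
datasets = xds_from_zarr("observation.zarr")

# Pin the epoch explicitly, e.g. when the same logical read must be
# reproduced on another client. Usually this should be left as None.
datasets = xds_from_zarr("observation.zarr", epoch="1d4f9c2a7b3e8d60")

# The same keyword is accepted on the write path.
writes = xds_to_zarr(datasets, "observation_copy.zarr", epoch="1d4f9c2a7b3e8d60")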
29 changes: 22 additions & 7 deletions daskms/experimental/zarr/__init__.py
@@ -2,6 +2,7 @@
from operator import mul
from pathlib import Path
import os.path
from uuid import uuid4
import warnings

import dask
@@ -216,7 +217,7 @@ def zarr_setter(data, name, group, *extents):
return np.full((1,) * len(extents), True)


def _gen_writes(variables, chunks, factory, indirect_dims=False):
def _gen_writes(variables, chunks, factory, epoch, indirect_dims=False):
for name, var in variables.items():
if isinstance(var.data, da.Array):
ext_args = extent_args(var.dims, var.chunks)
@@ -236,7 +237,9 @@ def _gen_writes(variables, chunks, factory, indirect_dims=False):
if var_data.nbytes == 0:
continue

token_name = f"write~{name}-" f"{tokenize(var_data, name, factory, *ext_args)}"
token_name = (
f"write~{name}-" f"{tokenize(var_data, name, factory, epoch, *ext_args)}"
)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=da.PerformanceWarning)
@@ -264,7 +267,9 @@ def _gen_writes(variables, chunks, factory, indirect_dims=False):


@requires("pip install dask-ms[zarr] for zarr support", zarr_import_error)
def xds_to_zarr(xds, store, columns=None, rechunk=False, consolidated=True, **kwargs):
def xds_to_zarr(
xds, store, columns=None, rechunk=False, consolidated=True, epoch=None, **kwargs
):
"""
Stores a dataset or list of datasets defined by `xds` in
file location `store`.
@@ -284,6 +289,9 @@ def xds_to_zarr(xds, store, columns=None, rechunk=False, consolidated=True, **kw
consistent with existing on-disk zarr arrays while writing to disk.
consolidated : bool
Controls whether metadata is consolidated
epoch : str or None
Uniquely identifies this instance of the returned dataset.
Should usually be set to None.
**kwargs : optional
Returns
@@ -306,6 +314,7 @@ def xds_to_zarr(xds, store, columns=None, rechunk=False, consolidated=True, **kw
UserWarning,
)

epoch = epoch or uuid4().hex[:16]
columns = promote_columns(columns)

if isinstance(xds, Dataset):
@@ -326,10 +335,10 @@ def xds_to_zarr(xds, store, columns=None, rechunk=False, consolidated=True, **kw

ds, group = prepare_zarr_group(di, ds, store, rechunk=rechunk)

data_vars = dict(_gen_writes(ds.data_vars, ds.chunks, group))
data_vars = dict(_gen_writes(ds.data_vars, ds.chunks, group, epoch))
# Include coords in the write dataset so they're reified
data_vars.update(
dict(_gen_writes(ds.coords, ds.chunks, group, indirect_dims=True))
dict(_gen_writes(ds.coords, ds.chunks, group, epoch, indirect_dims=True))
)

# Transfer any partition information over to the write dataset
@@ -368,7 +377,9 @@ def group_sortkey(element):


@requires("pip install dask-ms[zarr] for zarr support", zarr_import_error)
def xds_from_zarr(store, columns=None, chunks=None, consolidated=True, **kwargs):
def xds_from_zarr(
store, columns=None, chunks=None, consolidated=True, epoch=None, **kwargs
):
"""
Reads the zarr data store in `store` and returns list of
Datasets containing the data.
@@ -384,6 +395,9 @@ def xds_from_zarr(store, columns=None, chunks=None, consolidated=True, **kwargs)
chunking schema for each dataset
consolidated : bool
If True, attempt to read consolidated metadata
epoch : str or None
Uniquely identifies this instance of the returned dataset.
Should usually be set to None.
**kwargs: optional
Returns
@@ -409,6 +423,7 @@ def xds_from_zarr(store, columns=None, chunks=None, consolidated=True, **kwargs)
UserWarning,
)

epoch = epoch or uuid4().hex[:16]
columns = promote_columns(columns)

if chunks is None:
@@ -475,7 +490,7 @@ def xds_from_zarr(store, columns=None, chunks=None, consolidated=True, **kwargs)

array_chunks = da.core.normalize_chunks(array_chunks, zarray.shape)
ext_args = extent_args(dims, array_chunks)
token_name = f"read~{name}-{tokenize(zarray, *ext_args)}"
token_name = f"read~{name}-{tokenize(zarray, epoch, *ext_args)}"

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=da.PerformanceWarning)
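For context (an illustrative sketch, not part of the diff): the epoch is simply folded into dask.base.tokenize when building the read and write task names, so two otherwise identical graphs stop sharing keys once their epochs differ. The names below are stand-ins.

from dask.base import tokenize

zarray_id = "DATA"       # stand-in for the zarr array being read
ext_args = ((0, 10),)    # stand-in for the chunk extent arguments

# tokenize is deterministic, so without an epoch two independently built
# graphs over the same store would collide on identical keys.
assert tokenize(zarray_id, *ext_args) == tokenize(zarray_id, *ext_args)

# Folding a per-call epoch into the token makes the keys distinct.
key_a = f"read~DATA-{tokenize(zarray_id, 'epoch-a', *ext_args)}"
key_b = f"read~DATA-{tokenize(zarray_id, 'epoch-b', *ext_args)}"
assert key_a != key_b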
4 changes: 4 additions & 0 deletions daskms/table_proxy.py
@@ -164,6 +164,7 @@ def __call__(cls, *args, **kwargs):
return _table_cache[key]
except KeyError:
instance = type.__call__(cls, *args, **kwargs)
instance._hashvalue = key
_table_cache[key] = instance
return instance

@@ -363,6 +364,9 @@ def __init__(self, factory, *args, **kwargs):
def executor_key(self):
return self._ex_key

def __hash__(self):
return self._hashvalue

def __reduce__(self):
"""Defer to _map_create_proxy to support kwarg pickling"""
return (
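The table_proxy.py change stores the metaclass cache key on each new instance and returns it from __hash__, so a proxy hashes consistently with the key under which it was cached. A simplified, self-contained sketch of that pattern (class and variable names are illustrative, and hash(key) stands in for however the real cache key is reduced to an integer):

_cache = {}


class CachedMeta(type):
    """Return a cached instance when the constructor arguments repeat."""

    def __call__(cls, *args, **kwargs):
        key = (cls.__name__,) + args + tuple(sorted(kwargs.items()))
        try:
            return _cache[key]
        except KeyError:
            instance = type.__call__(cls, *args, **kwargs)
            # Stash the cache key's hash so __hash__ agrees with the cache.
            instance._hashvalue = hash(key)
            _cache[key] = instance
            return instance


class Proxy(metaclass=CachedMeta):
    def __init__(self, name):
        self.name = name

    def __hash__(self):
        return self._hashvalue


a = Proxy("observation.ms")
b = Proxy("observation.ms")
assert a is b and hash(a) == hash(b)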
