ZarrSimulationStore updates (#117)
* chunk and synchronizer encoding opts only at run()

* better handle existing zarr dataset

* default chunk size along batch dim = 1

* ensure no existing dataset (name conflict)
benbovy committed Apr 1, 2020
1 parent 83bbdeb commit e78ed54
Showing 4 changed files with 62 additions and 29 deletions.
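The first bullet moves `chunks` and `synchronizer` handling from model variables to run time. A minimal sketch of the caller-facing side (assuming the run-time encoding argument exposed by the run API on this branch; given an input dataset in_ds and a model, with illustrative names and values):

import zarr

out_ds = in_ds.xsimlab.run(
    model=model,
    store="output.zarr",
    encoding={
        # run-specific options: honored only here, not in xs.variable(encoding=...)
        "profile__u": {"chunks": (10,), "synchronizer": zarr.ThreadSynchronizer()},
    },
)
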
62 changes: 40 additions & 22 deletions xsimlab/stores.py
@@ -27,8 +27,15 @@ def _get_var_info(

for var_key, clock in var_clocks.items():
var_cache = model._var_cache[var_key]

# encoding defined at model run
run_encoding = normalize_encoding(
encoding.get(var_cache["name"]), extra_keys=["chunks", "synchronizer"]
)

# encoding defined in model variable + update
v_encoding = var_cache["metadata"]["encoding"]
v_encoding.update(normalize_encoding(encoding.get(var_cache["name"])))
v_encoding.update(run_encoding)

var_info[var_key] = {
"clock": clock,
@@ -41,6 +48,16 @@ def _get_var_info(
return var_info


def ensure_no_dataset_conflict(zgroup, znames):
existing_datasets = [name for name in znames if name in zgroup]

if existing_datasets:
raise ValueError(
f"Zarr path {zgroup.path} already contains the following datasets: "
+ ",".join(existing_datasets)
)
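
# usage sketch (names made up): fail fast before any simulation runs
#
#   zgroup = zarr.group()
#   zgroup.create_dataset("profile__u", shape=(3,), dtype="f8")
#   ensure_no_dataset_conflict(zgroup, ["profile__u", "clock"])
#   # -> ValueError: Zarr path ... already contains the following datasets: profile__u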


def default_fill_value_from_dtype(dtype=None):
if dtype is None:
return 0
@@ -55,6 +72,12 @@ def default_fill_value_from_dtype(dtype=None):
return 0


def get_auto_chunks(shape, dtype):
# A hack to get chunks guessed by zarr
arr = zarr.create(shape, dtype=dtype)
return arr.chunks
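
# e.g. zarr.create((10000, 10000), dtype="f8").chunks gives zarr's guessed
# (version-dependent) chunk shape; the throwaway array lives in memory and is
# never written to, so the trick is cheap.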


class ZarrSimulationStore:
def __init__(
self,
@@ -95,6 +118,10 @@ def __init__(
# initialize clock incrementers
self.clock_incs = self._init_clock_incrementers()

# ensure no dataset conflict in zarr group
znames = [vi["name"] for vi in self.var_info.values()]
ensure_no_dataset_conflict(self.zgroup, znames)

def _init_clock_incrementers(self):
clock_incs = {}

@@ -133,6 +160,7 @@ def _create_zarr_dataset(

dtype = getattr(value, "dtype", np.asarray(value).dtype)
shape = list(np.shape(value))
chunks = list(get_auto_chunks(shape, dtype))

add_batch_dim = (
self.batch_dim is not None
@@ -141,25 +169,27 @@

if clock is not None:
shape.insert(0, self.clock_sizes[clock])
chunks = list(get_auto_chunks(shape, dtype))
if add_batch_dim:
shape.insert(0, self.batch_size)
# by default: chunk of length 1 along batch dimension
chunks.insert(0, 1)
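
# worked example (illustrative): a variable of shape (3,) stored on a clock
# of size 100 in a batch of 2 simulations ends up with
#   shape  == [2, 100, 3]
#   chunks == [1, c0, c1]  where (c0, c1) = get_auto_chunks((100, 3), dtype)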

zkwargs = {
"shape": tuple(shape),
"chunks": True,
"chunks": chunks,
"dtype": dtype,
"compressor": "default",
"fill_value": default_fill_value_from_dtype(dtype),
}

zkwargs.update(var_info["encoding"])

# TODO: more performance assessment
# if self.in_memory:
# chunks = False
# compressor = None

zdataset = self.zgroup.create_dataset(name, **zkwargs)
try:
zdataset = self.zgroup.create_dataset(name, **zkwargs)
except ValueError:
# return early if the dataset already exists (batches of simulations)
return

# add dimension labels and variable attributes as metadata
dim_labels = None
@@ -191,18 +221,6 @@
# reset consolidated since metadata has just been updated
self.consolidated = False

def _maybe_create_zarr_dataset(
self, model: Model, var_key: VarKey, name: Optional[str] = None,
):
# do not create if already exists (only for batches of simulation)
try:
self._create_zarr_dataset(model, var_key, name=name)
except ValueError as err:
if self.batch_dim:
pass
else:
raise err

def _maybe_resize_zarr_dataset(
self, model: Model, var_key: VarKey,
):
@@ -248,7 +266,7 @@ def write_output_vars(self, batch: int, step: int, model: Optional[Model] = None

if clock_inc == 0:
for vk in var_keys:
self._maybe_create_zarr_dataset(model, vk)
self._create_zarr_dataset(model, vk)

for vk in var_keys:
zkey = self.var_info[vk]["name"]
@@ -284,7 +302,7 @@ def write_index_vars(self, model: Optional[Model] = None):
_, vname = var_key
model.cache_state(var_key)

self._maybe_create_zarr_dataset(model, var_key, name=vname)
self._create_zarr_dataset(model, var_key, name=vname)
self.zgroup[vname][:] = model._var_cache[var_key]["value"]

def consolidate(self):
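Net effect of the stores.py changes: datasets are created lazily at the first output step, batch members share one zarr group with only the first member creating each dataset, and name conflicts with pre-existing datasets are caught up front. A rough usage sketch, following the constructor call used in the tests below (the batch/step argument values are illustrative; batch=-1 appears to denote a non-batch run):

store = ZarrSimulationStore(in_ds, model, zobject=zarr.group())
store.write_input_xr_dataset()
store.write_output_vars(-1, 0)   # first step: zarr datasets created here
store.write_output_vars(-1, 1)   # later steps: datasets only resized/filled
store.consolidate()
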
18 changes: 14 additions & 4 deletions xsimlab/tests/test_stores.py
@@ -71,6 +71,13 @@ def test_constructor(self, in_ds, model, zobj, tmpdir):
def test_constructor_batch(self, store_batch):
assert store_batch.batch_size == 2

def test_constructor_conflict(self, in_ds, model):
zgroup = zarr.group()
zgroup.create_dataset("profile__u", shape=(1, 1))

with pytest.raises(ValueError, match=r".*already contains.*"):
ZarrSimulationStore(in_ds, model, zobject=zgroup)

def test_write_input_xr_dataset(self, in_ds, store):
store.write_input_xr_dataset()
ds = xr.open_zarr(store.zgroup.store, chunks=None)
@@ -146,6 +153,9 @@ def test_write_output_vars_batch(self, store_batch, model_batch1, model_batch2):

assert_array_equal(ztest.add__offset[:], np.array([2.0, 3.0]))

# test default chunk size along batch dim
assert ztest.profile__u.chunks[0] == 1

def test_write_index_vars(self, store):
store.model.state[("init_profile", "x")] = np.array([1.0, 2.0, 3.0])

@@ -190,7 +200,7 @@ class P:
def test_encoding(self):
@xs.process
class P:
v1 = xs.variable(dims="x", intent="out", encoding={"chunks": (10,)})
v1 = xs.variable(dims="x", intent="out", encoding={"dtype": np.int32})
v2 = xs.on_demand(dims="x", encoding={"fill_value": 0})
v3 = xs.index(dims="x")

@@ -209,7 +219,7 @@ def _get_v2(self):
store = ZarrSimulationStore(
in_ds,
model,
encoding={"p__v2": {"fill_value": -1}, "p__v3": {"compressor": None}},
encoding={"p__v2": {"fill_value": -1}, "p__v3": {"chunks": (10,)}},
)

model.state[("p", "v1")] = [0]
@@ -218,10 +228,10 @@ def _get_v2(self):

ztest = zarr.open_group(store.zgroup.store, mode="r")

assert ztest.p__v1.chunks == (10,)
assert ztest.p__v1.dtype == np.int32
# test encoding precedence ZarrSimulationStore > model variable
assert ztest.p__v2.fill_value == -1
assert ztest.p__v3.compressor is None
assert ztest.p__v3.chunks == (10,)
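
# the precedence pinned down above, reduced to the dict merge in _get_var_info
# (run-level encoding wins over the model-variable encoding; values illustrative):
#
#   v_encoding = {"fill_value": 0}                      # from xs.variable(encoding=...)
#   run_encoding = {"fill_value": -1, "chunks": (10,)}  # from ZarrSimulationStore(encoding=...)
#   v_encoding.update(run_encoding)                     # run-level wins; "chunks" only valid here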

def test_open_as_xr_dataset(self, store):
model = store.model
5 changes: 4 additions & 1 deletion xsimlab/tests/test_utils.py
@@ -41,7 +41,6 @@ def test_normalize_encoding():
assert utils.normalize_encoding(None) == {}

encoding = {
"chunks": True,
"dtype": "int",
"compressor": None,
"fill_value": 0,
@@ -55,6 +54,10 @@ def test_normalize_encoding():
encoding.pop("ignored_key")
assert actual == encoding

encoding = {"chunks": True}
actual = utils.normalize_encoding(encoding, extra_keys=["chunks"])
assert actual == encoding


def test_get_batch_size():
ds = xr.Dataset({"bdim": ("bdim", [1, 2, 3])})
6 changes: 4 additions & 2 deletions xsimlab/utils.py
@@ -44,9 +44,8 @@ def import_required(mod_name, error_msg):
raise RuntimeError(error_msg)


def normalize_encoding(encoding):
def normalize_encoding(encoding, extra_keys=None):
used_keys = [
"chunks",
"dtype",
"compressor",
"fill_value",
@@ -55,6 +54,9 @@ def normalize_encoding(encoding):
"object_codec",
]

if extra_keys is not None:
used_keys += extra_keys

if encoding is None:
return {}
else:
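For reference, the resulting behavior of normalize_encoding, as pinned down by the tests above (a sketch; the else branch truncated in this diff filters the dict down to the recognized keys):

normalize_encoding({"chunks": True, "dtype": "int", "unknown": 1})
# -> {"dtype": "int"}  ("chunks" is no longer a default key; unknown keys are dropped)
normalize_encoding({"chunks": True}, extra_keys=["chunks"])
# -> {"chunks": True}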
