Auto-resize zarr datasets in output store (#111)
* cache variable value

It is read several times in different methods when saving one snapshot.
This avoids performance issues, especially for on-demand variables.

* maybe resize zarr datasets when saving snapshots

* add test

* update release notes

* also test skip resize and write a slice

* doc: add section with an example in user guide

* typo
benbovy committed Mar 16, 2020
1 parent b4fbd30 commit 73c80c0
Showing 4 changed files with 123 additions and 7 deletions.
57 changes: 57 additions & 0 deletions doc/io_storage.rst
@@ -143,3 +143,60 @@ computing with Dask`_ in xarray's docs).
os.remove("model2_setup.nc")
os.remove("model2_run.nc")
shutil.rmtree("model2_run.zarr")

Advanced usage
--------------

Dynamically sized arrays
~~~~~~~~~~~~~~~~~~~~~~~~

Model variables may have one or more dimensions whose size changes dynamically
during a simulation. When such variables are saved as outputs, the
corresponding zarr datasets are resized as needed so that, at the end of the
simulation, all values are stored in arrays of fixed shape, possibly padded
with missing values (note: depending on chunk size, zarr does not need to
physically store regions of contiguous missing values).
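
To see what this means at the zarr level, here is a minimal sketch using zarr
directly (assuming its version 2 API; this is not xarray-simlab code): growing
an array copies no data, and chunks that are never written are not physically
stored but read back as the array's ``fill_value``:

.. ipython::

    In [1]: import zarr

    In [1]: z = zarr.create(shape=(1, 2), chunks=(1, 2), dtype="f8",
       ...:                 fill_value=float("nan"))

    In [1]: z.resize(3, 4)   # grow both dimensions, keeping existing data

    In [1]: z[1, :3] = 1.0   # write a partial slice

    In [1]: z[:]             # unwritten regions read back as NaN

    In [1]: z.nchunks_initialized   # only written chunks are stored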

The example below illustrates how such variables are returned as outputs:

.. ipython::

In [2]: import numpy as np

In [3]: @xs.process
...: class Particles:
...: """Generate at each step a random number of particles
...: at random positions along an axis.
...: """
...:
...: position = xs.variable(dims='pt', intent='out')
...:
...: def initialize(self):
...: self._rng = np.random.default_rng(123)
...:
...: def run_step(self):
...: nparticles = self._rng.integers(1, 4)
...: self.position = self._rng.uniform(0, 10, size=nparticles)
...:

In [4]: model = xs.Model({'pt': Particles})

In [5]: with model:
...: in_ds = xs.create_setup(clocks={'steps': range(4)},
...: output_vars={'pt__position': 'steps'})
...: out_ds = in_ds.xsimlab.run()
...:

In [6]: out_ds.pt__position

N-dimensional arrays with missing values might not be the most convenient
format for dealing with this kind of output data. It can be converted into a
denser format, e.g., a :class:`pandas.DataFrame` with a multi-index, using the
xarray Dataset or DataArray :meth:`~xarray.Dataset.stack`,
:meth:`~xarray.Dataset.dropna` and :meth:`~xarray.Dataset.to_dataframe` methods:

.. ipython::

In [7]: (out_ds.stack(particles=('steps', 'pt'))
...: .dropna('particles')
...: .to_dataframe())
2 changes: 1 addition & 1 deletion doc/whats_new.rst
@@ -70,7 +70,7 @@ Enhancements
:class:`~xsimlab.RuntimeHook` class.
- Added some useful properties and methods to the ``xarray.Dataset.xsimlab``
extension (:issue:`103`).
-- Save model inputs/outputs using zarr (:issue:`102`).
+- Save model inputs/outputs using zarr (:issue:`102`, :issue:`111`).
- Added :class:`~xsimlab.monitoring.ProgressBar` to track simulation progress
(:issue:`104`, :issue:`110`).

45 changes: 39 additions & 6 deletions xsimlab/stores.py
@@ -35,6 +35,8 @@ def _get_var_info(dataset: xr.Dataset, model: Model) -> Dict[Tuple[str, str], Di
"name": f"{p_name}__{v_name}",
"obj": v_obj,
"value_getter": _variable_value_getter(p_obj, v_name),
"value": None,
"shape": None,
}

return var_info
@@ -93,23 +95,31 @@ def write_input_xr_dataset(self):
ds.xsimlab._reset_output_vars(self.model, {})
ds.to_zarr(self.zgroup.store, group=self.zgroup.path, mode="a")

def _cache_value_as_array(self, var_key):
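        # the value is read several times while saving a single snapshot;
        # caching it as an array avoids re-computing on-demand variables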
value = self.var_info[var_key]["value_getter"]()

if np.isscalar(value):
value = np.asarray(value)

self.var_info[var_key]["value"] = value

def _create_zarr_dataset(self, var_key: Tuple[str, str], name=None):
var = self.var_info[var_key]["obj"]

if name is None:
name = self.var_info[var_key]["name"]

-        array = self.var_info[var_key]["value_getter"]()
-        if np.isscalar(array):
-            array = np.asarray(array)
-
+        array = self.var_info[var_key]["value"]
clock = self.var_info[var_key]["clock"]

if clock is None:
shape = array.shape
else:
shape = (self.clock_sizes[clock],) + tuple(array.shape)

# init shape for dynamically sized arrays
self.var_info[var_key]["shape"] = np.asarray(shape)

chunks = True
compressor = "default"

@@ -152,6 +162,23 @@ def _create_zarr_dataset(self, var_key: Tuple[str, str], name=None):
# reset consolidated since metadata has just been updated
self.consolidated = False

def _maybe_resize_zarr_dataset(self, var_key: Tuple[str, str]):
# Maybe increases the length of one or more dimensions of
# the zarr array (only increases, never shrinks dimensions).

zkey = self.var_info[var_key]["name"]
zshape = self.var_info[var_key]["shape"]
array = self.var_info[var_key]["value"]

# prepend clock dim
array_shape = np.concatenate(([0], array.shape))

new_shape = np.maximum(zshape, array_shape)
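        # e.g. stored shape (3, 2) and current value shape (4,):
        # array_shape == (0, 4), so new_shape == (3, 4) (the clock dim is
        # preserved and dims only ever grow)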

if np.any(new_shape > zshape):
self.var_info[var_key]["shape"] = new_shape
self.zgroup[zkey].resize(new_shape)

def write_output_vars(self, istep: int):
save_istep = self.output_save_steps.isel(**{self.mclock_dim: istep})

@@ -163,24 +190,30 @@ def write_output_vars(self, istep: int):

clock_inc = self.clock_incs[clock]

for vk in var_keys:
self._cache_value_as_array(vk)

if clock_inc == 0:
for vk in var_keys:
self._create_zarr_dataset(vk)

for vk in var_keys:
zkey = self.var_info[vk]["name"]
-                array = self.var_info[vk]["value_getter"]()
+                array = self.var_info[vk]["value"]

if clock is None:
self.zgroup[zkey][:] = array
else:
-                    self.zgroup[zkey][clock_inc] = array
+                    self._maybe_resize_zarr_dataset(vk)
+                    idx = tuple([clock_inc] + [slice(0, n) for n in array.shape])
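+                    # e.g. clock_inc=2 and array.shape == (2,) with zarr shape
+                    # (4, 3): idx == (2, slice(0, 2)), so the last column of
+                    # this row keeps its fill value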
+                    self.zgroup[zkey][idx] = array

self.clock_incs[clock] += 1

def write_index_vars(self):
for var_key in self.model.index_vars:
_, vname = var_key
self._cache_value_as_array(var_key)
self._create_zarr_dataset(var_key, name=vname)

array = self.var_info[var_key]["value_getter"]()
26 changes: 26 additions & 0 deletions xsimlab/tests/test_stores.py
@@ -6,6 +6,7 @@
import xarray as xr
import zarr

import xsimlab as xs
from xsimlab.stores import ZarrOutputStore


@@ -89,6 +90,31 @@ def test_write_index_vars(self, in_dataset, model):

assert_array_equal(ztest.x, np.array([1.0, 2.0, 3.0]))

def test_resize_zarr_dataset(self):
@xs.process
class P:
arr = xs.variable(dims="x", intent="out")

model = xs.Model({"p": P})

in_ds = xs.create_setup(
model=model, clocks={"clock": [0, 1, 2]}, output_vars={"p__arr": "clock"},
)

_bind_store(model)
out_store = ZarrOutputStore(in_ds, model, None)

for step, size in zip([0, 1, 2], [1, 3, 2]):
model.store[("p", "arr")] = np.ones(size)
out_store.write_output_vars(step)

ztest = zarr.open_group(out_store.zgroup.store, mode="r")

expected = np.array(
[[1.0, np.nan, np.nan], [1.0, 1.0, 1.0], [1.0, 1.0, np.nan]]
)
assert_array_equal(ztest.p__arr, expected)

def test_open_as_xr_dataset(self, in_dataset, model):
_bind_store(model)
out_store = ZarrOutputStore(in_dataset, model, None)
