Skip to content

Commit

Permalink
Add encoding options for storing model variables (#113)
Browse files Browse the repository at this point in the history
* add encoding parameter to variable functions

* add encoding param to run + implement encoding

* fix tuple/list values written in zarr store

* add tests

* black

* update release notes

* doc: add encoding options subsection with examples

* test encoding option supplied for run() only
  • Loading branch information
benbovy authored Mar 18, 2020
1 parent 928320d commit c59ff44
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 32 deletions.
49 changes: 49 additions & 0 deletions doc/io_storage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,52 @@ to the xarray Dataset or DataArray :meth:`~xarray.Dataset.stack`,
In [7]: (out_ds.stack(particles=('steps', 'pt'))
...: .dropna('particles')
...: .to_dataframe())

Encoding options
~~~~~~~~~~~~~~~~

It is possible to control via some encoding options how Zarr stores simulation
data.

Those options can be set for variables declared in process classes. See the
``encoding`` parameter of :func:`~xsimlab.variable` for all available options.
In the example below we specify a custom fill value for the ``position``
variable, which will be used to replace missing values:

.. ipython::

In [4]: @xs.process
...: class Particles:
...: position = xs.variable(dims='pt', intent='out',
...: encoding={'fill_value': -1.0})
...:
...: def initialize(self):
...: self._rng = np.random.default_rng(123)
...:
...: def run_step(self):
...: nparticles = self._rng.integers(1, 4)
...: self.position = self._rng.uniform(0, 10, size=nparticles)
...:

In [5]: model = xs.Model({'pt': Particles})

In [6]: with model:
...: in_ds = xs.create_setup(clocks={'steps': range(4)},
...: output_vars={'pt__position': 'steps'})
...: out_ds = in_ds.xsimlab.run()
...:

In [7]: out_ds.pt__position

Encoding options may also be set or overridden when calling
:func:`~xarray.Dataset.xsimlab.run`, e.g.,

.. ipython::

In [8]: out_ds = in_ds.xsimlab.run(
...: model=model,
...: encoding={'pt__position': {'fill_value': -10.0}}
...: )
...:

In [9]: out_ds.pt__position
3 changes: 2 additions & 1 deletion doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ Enhancements
:class:`~xsimlab.RuntimeHook` class.
- Added some useful properties and methods to the ``xarray.Dataset.xsimlab``
extension (:issue:`103`).
- Save model inputs/outputs using zarr (:issue:`102`, :issue:`111`).
- Save model inputs/outputs using zarr (:issue:`102`, :issue:`111`,
:issue:`113`).
- Added :class:`~xsimlab.monitoring.ProgressBar` to track simulation progress
(:issue:`104`, :issue:`110`).

Expand Down
3 changes: 2 additions & 1 deletion xsimlab/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def __init__(
model,
state,
store,
encoding,
check_dims=CheckDimsOption.STRICT,
validate=ValidateOption.INPUTS,
hooks=None,
Expand Down Expand Up @@ -204,7 +205,7 @@ def __init__(
hooks = set(hooks) | RuntimeHook.active
self._hooks = group_hooks(flatten_hooks(hooks))

self.store = ZarrSimulationStore(dataset, model, store)
self.store = ZarrSimulationStore(dataset, model, store, encoding)

def _check_missing_model_inputs(self):
"""Check if all model inputs have their corresponding variables
Expand Down
44 changes: 28 additions & 16 deletions xsimlab/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import xarray as xr
import zarr

from xsimlab import Model
from xsimlab.process import variables_dict
from . import Model
from .process import variables_dict
from .utils import normalize_encoding


_DIMENSION_KEY = "_ARRAY_DIMENSIONS"
Expand All @@ -19,24 +20,31 @@ def value_getter():
return value_getter


def _get_var_info(dataset: xr.Dataset, model: Model) -> Dict[Tuple[str, str], Dict]:
def _get_var_info(
dataset: xr.Dataset, model: Model, encoding: Dict[str, Dict[str, Any]]
) -> Dict[Tuple[str, str], Dict]:
var_info = {}

var_clocks = {k: v for k, v in dataset.xsimlab.output_vars.items()}
var_clocks.update({vk: None for vk in model.index_vars})

for var_key, clock in var_clocks.items():
p_name, v_name = var_key
v_name_str = f"{p_name}__{v_name}"
p_obj = model[p_name]
v_obj = variables_dict(type(p_obj))[v_name]

v_encoding = v_obj.metadata["encoding"]
v_encoding.update(normalize_encoding(encoding.get(v_name_str)))

var_info[var_key] = {
"clock": clock,
"name": f"{p_name}__{v_name}",
"name": v_name_str,
"obj": v_obj,
"value_getter": _variable_value_getter(p_obj, v_name),
"value": None,
"shape": None,
"encoding": v_encoding,
}

return var_info
Expand All @@ -60,6 +68,7 @@ def __init__(
dataset: xr.Dataset,
model: Model,
zobject: Union[zarr.Group, MutableMapping, str, None],
encoding: Union[Dict[str, Dict[str, Any]], None],
):
self.dataset = dataset
self.model = model
Expand All @@ -78,7 +87,10 @@ def __init__(
self.output_vars = dataset.xsimlab.output_vars_by_clock
self.output_save_steps = dataset.xsimlab.get_output_save_steps()

self.var_info = _get_var_info(dataset, model)
if encoding is None:
encoding = {}

self.var_info = _get_var_info(dataset, model, encoding)

self.mclock_dim = dataset.xsimlab.master_clock_dim
self.clock_sizes = dataset.xsimlab.clock_sizes
Expand All @@ -98,7 +110,7 @@ def write_input_xr_dataset(self):
def _cache_value_as_array(self, var_key):
value = self.var_info[var_key]["value_getter"]()

if np.isscalar(value):
if np.isscalar(value) or isinstance(value, (list, tuple)):
value = np.asarray(value)

self.var_info[var_key]["value"] = value
Expand All @@ -120,22 +132,22 @@ def _create_zarr_dataset(self, var_key: Tuple[str, str], name=None):
# init shape for dynamically sized arrays
self.var_info[var_key]["shape"] = np.asarray(shape)

chunks = True
compressor = "default"
zkwargs = {
"shape": shape,
"chunks": True,
"dtype": array.dtype,
"compressor": "default",
"fill_value": default_fill_value_from_dtype(array.dtype),
}

zkwargs.update(self.var_info[var_key]["encoding"])

# TODO: more performance assessment
# if self.in_memory:
# chunks = False
# compressor = None

zdataset = self.zgroup.create_dataset(
name,
shape=shape,
chunks=chunks,
dtype=array.dtype,
compressor=compressor,
fill_value=default_fill_value_from_dtype(array.dtype),
)
zdataset = self.zgroup.create_dataset(name, **zkwargs)

# add dimension labels and variable attributes as metadata
dim_labels = None
Expand Down
1 change: 1 addition & 0 deletions xsimlab/tests/fixture_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def in_var_details():
- groups : ()
- static : False
- attrs : {}
- encoding : {}
"""
)

Expand Down
8 changes: 4 additions & 4 deletions xsimlab/tests/test_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def base_driver(model):
@pytest.fixture
def xarray_driver(in_dataset, model):
state = {}
return XarraySimulationDriver(in_dataset, model, state, None)
return XarraySimulationDriver(in_dataset, model, state, None, None)


def test_runtime_context():
Expand Down Expand Up @@ -92,12 +92,12 @@ def test_constructor(self, in_dataset, model):

invalid_ds = in_dataset.drop("clock")
with pytest.raises(ValueError) as excinfo:
XarraySimulationDriver(invalid_ds, model, state, None)
XarraySimulationDriver(invalid_ds, model, state, None, None)
assert "Missing master clock" in str(excinfo.value)

invalid_ds = in_dataset.drop("init_profile__n_points")
with pytest.raises(KeyError) as excinfo:
XarraySimulationDriver(invalid_ds, model, state, None)
XarraySimulationDriver(invalid_ds, model, state, None, None)
assert "Missing variables" in str(excinfo.value)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -142,7 +142,7 @@ def run_step(self, arg):

m = model.update_processes({"p": P})

driver = XarraySimulationDriver(in_dataset, m, {}, None)
driver = XarraySimulationDriver(in_dataset, m, {}, None, None)

with pytest.raises(KeyError, match="'not_a_runtime_arg'"):
driver.run_model()
2 changes: 2 additions & 0 deletions xsimlab/tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class Dummy_placeholder:
- groups : ()
- static : False
- attrs : {}
- encoding : {}
var2 : object
(no description given)
Expand All @@ -90,6 +91,7 @@ class Dummy_placeholder:
- groups : ()
- static : False
- attrs : {}
- encoding : {}
"""

Expand Down
54 changes: 46 additions & 8 deletions xsimlab/tests/test_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ class TestZarrSimulationStore:
"zobject", [None, mkdtemp(), zarr.MemoryStore(), zarr.group()]
)
def test_constructor(self, in_dataset, model, zobject):
out_store = ZarrSimulationStore(in_dataset, model, zobject)
out_store = ZarrSimulationStore(in_dataset, model, zobject, None)

assert out_store.zgroup.store is not None

def test_write_input_xr_dataset(self, in_dataset, model):
out_store = ZarrSimulationStore(in_dataset, model, None)
out_store = ZarrSimulationStore(in_dataset, model, None, None)

out_store.write_input_xr_dataset()
ds = xr.open_zarr(out_store.zgroup.store, chunks=None)
Expand All @@ -40,7 +40,7 @@ def test_write_input_xr_dataset(self, in_dataset, model):

def test_write_output_vars(self, in_dataset, model):
_bind_state(model)
out_store = ZarrSimulationStore(in_dataset, model, None)
out_store = ZarrSimulationStore(in_dataset, model, None, None)

model.state[("profile", "u")] = np.array([1.0, 2.0, 3.0])
model.state[("roll", "u_diff")] = np.array([-1.0, 1.0, 0.0])
Expand Down Expand Up @@ -70,7 +70,7 @@ def test_write_output_vars(self, in_dataset, model):

def test_write_output_vars_error(self, in_dataset, model):
_bind_state(model)
out_store = ZarrSimulationStore(in_dataset, model, None)
out_store = ZarrSimulationStore(in_dataset, model, None, None)

model.state[("profile", "u")] = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
model.state[("roll", "u_diff")] = np.array([-1.0, 1.0, 0.0])
Expand All @@ -81,7 +81,7 @@ def test_write_output_vars_error(self, in_dataset, model):

def test_write_index_vars(self, in_dataset, model):
_bind_state(model)
out_store = ZarrSimulationStore(in_dataset, model, None)
out_store = ZarrSimulationStore(in_dataset, model, None, None)

model.state[("init_profile", "x")] = np.array([1.0, 2.0, 3.0])

Expand All @@ -102,7 +102,7 @@ class P:
)

_bind_state(model)
out_store = ZarrSimulationStore(in_ds, model, None)
out_store = ZarrSimulationStore(in_ds, model, None, None)

for step, size in zip([0, 1, 2], [1, 3, 2]):
model.state[("p", "arr")] = np.ones(size)
Expand All @@ -115,9 +115,47 @@ class P:
)
assert_array_equal(ztest.p__arr, expected)

def test_encoding(self):
@xs.process
class P:
v1 = xs.variable(dims="x", intent="out", encoding={"chunks": (10,)})
v2 = xs.on_demand(dims="x", encoding={"fill_value": 0})
v3 = xs.index(dims="x")

@v2.compute
def _get_v2(self):
return [0]

model = xs.Model({"p": P})

in_ds = xs.create_setup(
model=model,
clocks={"clock": [0]},
output_vars={"p__v1": None, "p__v2": None, "p__v3": None},
)

_bind_state(model)
out_store = ZarrSimulationStore(
in_ds,
model,
None,
{"p__v2": {"fill_value": -1}, "p__v3": {"compressor": None}},
)

model.state[("p", "v1")] = [0]
model.state[("p", "v3")] = [0]
out_store.write_output_vars(-1)

ztest = zarr.open_group(out_store.zgroup.store, mode="r")

assert ztest.p__v1.chunks == (10,)
# test encoding precedence ZarrSimulationStore > model variable
assert ztest.p__v2.fill_value == -1
assert ztest.p__v3.compressor is None

def test_open_as_xr_dataset(self, in_dataset, model):
_bind_state(model)
out_store = ZarrSimulationStore(in_dataset, model, None)
out_store = ZarrSimulationStore(in_dataset, model, None, None)

model.state[("profile", "u")] = np.array([1.0, 2.0, 3.0])
model.state[("roll", "u_diff")] = np.array([-1.0, 1.0, 0.0])
Expand All @@ -130,7 +168,7 @@ def test_open_as_xr_dataset(self, in_dataset, model):

def test_open_as_xr_dataset_chunks(self, in_dataset, model):
_bind_state(model)
out_store = ZarrSimulationStore(in_dataset, model, mkdtemp())
out_store = ZarrSimulationStore(in_dataset, model, mkdtemp(), None)

model.state[("profile", "u")] = np.array([1.0, 2.0, 3.0])
model.state[("roll", "u_diff")] = np.array([-1.0, 1.0, 0.0])
Expand Down
19 changes: 19 additions & 0 deletions xsimlab/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@ def test_import_required():
assert err_msg in str(excinfo.value)


def test_normalize_encoding():
assert utils.normalize_encoding(None) == {}

encoding = {
"chunks": True,
"dtype": "int",
"compressor": None,
"fill_value": 0,
"order": "C",
"filters": None,
"object_codec": None,
"ignored_key": None,
}

actual = utils.normalize_encoding(encoding)
encoding.pop("ignored_key")
assert actual == encoding


class TestAttrMapping:
@pytest.fixture
def attr_mapping(self):
Expand Down
Loading

0 comments on commit c59ff44

Please sign in to comment.