Some API changes (Model runtime) (#121)
* rename set_inputs -> update_state + move validate

* rename cache_state -> update_cache

* unrelated, minor doc improvements

* black
benbovy committed Apr 6, 2020
1 parent 32bee7b commit 9d7cbc0
Showing 8 changed files with 85 additions and 64 deletions.
12 changes: 7 additions & 5 deletions doc/api.rst
@@ -17,6 +17,8 @@ Top-level functions

create_setup

.. _api_xarray_accessor:

Dataset.xsimlab (xarray accessor)
=================================

@@ -104,16 +106,16 @@ Running a model
---------------

In most cases, the methods and properties listed below should not be used
directly. For running simulations, it is preferable to use the
``Dataset.xsimlab`` accessor instead. These methods might be useful though,
e.g., for debugging or for using ``Model`` objects with other interfaces.
directly. For running simulations, it is preferable to use the xarray extension
instead, see :ref:`api_xarray_accessor`. These methods might be useful though,
e.g., for using ``Model`` objects with other interfaces.

.. autosummary::
:toctree: _api_generated/

Model.state
Model.cache_state
Model.set_inputs
Model.update_state
Model.update_cache
Model.execute
Model.validate
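The renamed runtime methods listed above are normally driven by xsimlab itself; their call order in a simulation can be sketched with a stand-in class (``MiniModel`` below is an invented stub, not the real ``xsimlab.Model``):

```python
# Hypothetical stub showing the call order of the renamed runtime methods.
# MiniModel is an invented stand-in; the real xsimlab.Model does much more.

class MiniModel:
    def __init__(self):
        self.state = {}
        self._var_cache = {}
        self.log = []

    def update_state(self, input_vars, validate=True, ignore_static=False):
        # formerly set_inputs(); validation is now folded in
        self.state.update(input_vars)
        if validate:
            self.log.append("validated")

    def update_cache(self, var_key):
        # formerly cache_state(); snapshot the current value (no copy)
        self._var_cache[var_key] = self.state[var_key]

    def execute(self, stage):
        self.log.append(stage)


model = MiniModel()
model.update_state({("proc", "x"): 0.0}, ignore_static=True)
model.execute("initialize")
for step in range(2):
    model.update_state({("proc", "x"): float(step + 1)})
    model.execute("run_step")
    model.update_cache(("proc", "x"))

print(model._var_cache[("proc", "x")])  # 2.0
```

This mirrors the sequence the driver follows (update state, execute a stage, cache output values), but the stub's method bodies are placeholders only.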

2 changes: 2 additions & 0 deletions doc/inspect_model.rst
@@ -107,6 +107,8 @@ Like :attr:`~xsimlab.Model.input_vars` and
:attr:`~xsimlab.Model.all_vars_dict` are available for all model
variables, not only inputs.

.. _inspect_model_visualize:

Visualize models as graphs
--------------------------

6 changes: 6 additions & 0 deletions doc/monitor.rst
@@ -91,6 +91,12 @@ For a full list of customization options, refer to the `Tqdm documentation`_.

.. _Tqdm documentation: https://tqdm.github.io

.. note::

Currently this progress bar doesn't support tracking the progress of batches
   of simulations. If those batches are run in parallel, you can
use Dask's diagnostics instead.

.. _custom_runtime_hooks:

Custom runtime hooks
35 changes: 18 additions & 17 deletions doc/run_parallel.rst
@@ -28,12 +28,13 @@ library. There are two parallel modes:
Single-model parallelism
------------------------

This mode runs each process in a model in parallel.
This mode runs the processes in a model in parallel.

A :class:`~xsimlab.Model` object can be viewed as a Directed Acyclic Graph (DAG)
built from a collection of processes (i.e., process-decorated classes). At each
simulation stage, a task graph is built from this graph, which is then executed
by one of the schedulers available in Dask.
built from a collection of processes (i.e., process-decorated classes) as nodes
and their inter-dependencies as directed edges. At each simulation stage, a task
graph is built from this DAG, which is then executed by one of the schedulers
available in Dask.

To activate this parallel mode, simply set ``parallel=True`` when calling
:func:`xarray.Dataset.xsimlab.run`:
@@ -44,22 +45,22 @@ To activate this parallel mode, simply set ``parallel=True`` when calling
The default Dask scheduler used here is ``"threads"`` (this is the one used by
``dask.delayed``). Other schedulers may be selected via the ``scheduler``
argument of :func:`~xarray.Dataset.xsimlab.run`. Dask also supports other ways to
argument of :func:`~xarray.Dataset.xsimlab.run`. Dask also provides other ways to
select a scheduler, see `here
<https://docs.dask.org/en/latest/setup/single-machine.html>`_.

Note, however, that multi-processes schedulers are not supported for this mode,
since simulation active data (shared between all model components) is stored
using a simple Python dictionary.
Multi-processes schedulers are not supported for this mode since simulation
active data, shared between all model components, is stored using a simple
Python dictionary.

Note also that the code in the process-decorated classes must be thread-safe
and should release the CPython's Global Interpreter Lock (GIL) as much as
possible in order to see a gain in performance. For example, most Numpy
functions release the GIL.
The code in the process-decorated classes must be thread-safe and should release
CPython's Global Interpreter Lock (GIL) as much as possible in order to see
a gain in performance. For example, most Numpy functions release the GIL.

The gain in performance compared to sequential execution of the model processes
will also depend on how the DAG is structured, i.e., how many processes can be
executed in parallel.
executed in parallel. Visualizing the DAG helps a lot; see Section
:ref:`inspect_model_visualize`.
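The DAG-to-task-graph idea described above can be sketched with the standard library alone. The process names and dependencies below are invented, and xsimlab builds a Dask task graph rather than using ``graphlib`` directly:

```python
# Sketch of "single-model parallelism": topologically sort the process DAG and
# run each batch of ready (mutually independent) processes in a thread pool.
# Process names/edges are invented; xsimlab delegates scheduling to Dask.
from concurrent.futures import ThreadPoolExecutor
from graphlib import TopologicalSorter

# process -> processes it depends on (the DAG's directed edges)
dag = {
    "grid": set(),
    "init_profile": {"grid"},
    "advect": {"grid"},
    "profile": {"init_profile", "advect"},
}

executed = []

def run_process(name):
    executed.append(name)  # placeholder for running one simulation stage
    return name

sorter = TopologicalSorter(dag)
sorter.prepare()
with ThreadPoolExecutor() as pool:
    while sorter.is_active():
        # init_profile and advect become ready together and run in parallel
        for name in pool.map(run_process, sorter.get_ready()):
            sorter.done(name)

print(executed[0], executed[-1])  # grid profile
```

How much runs concurrently depends entirely on the DAG's shape: here only the middle two processes can overlap, which is exactly why inspecting the graph structure matters.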

.. _run_parallel_multi:

@@ -84,7 +85,7 @@ while calling :func:`xarray.Dataset.xsimlab.run` (see Section

.. code:: python
>>> in_ds.xsimlab.run(model=my_model, batch_dim="batch", parallel=True)
>>> in_ds.xsimlab.run(model=my_model, batch_dim="batch", parallel=True, store="output.zarr")
As opposed to single-model parallelism, both multi-threads and multi-processes
Dask schedulers are supported for this embarrassingly parallel problem.
@@ -102,8 +103,8 @@ If you use a multi-processes scheduler, beware of the following:
- By default, the chunk size of Zarr datasets along the batch dimension is equal
to 1 in order to prevent race conditions during parallel writes. This might
not be optimal for further post-processing, though. It is possible to override
this default and set larger chunk sizes (via the ``encoding`` parameter of
:func:`~xarray.Dataset.xsimlab.run`), but then you should also use one of the
Zarr's synchronizers (either :class:`zarr.sync.ThreadSynchronizer` or
this default value and set larger chunk sizes via the ``encoding`` parameter
of :func:`~xarray.Dataset.xsimlab.run`, but then you should also use one of
Zarr's synchronizers (either :class:`zarr.sync.ThreadSynchronizer` or
:class:`zarr.sync.ProcessSynchronizer`) to ensure that all output values will
be properly saved.
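The chunking argument above can be illustrated with a toy stdlib sketch: with a chunk size of 1 along the batch dimension, every simulation owns its own slot and writes race-free; the lock below plays the role a Zarr synchronizer would play if larger chunks forced concurrent writers to share a chunk (everything here is a stand-in, not the Zarr API):

```python
# Toy illustration of parallel batch writes. Each simulation writes only its
# own slot along the batch dimension, which is race-free when the chunk size
# along that dimension is 1. The lock stands in for a Zarr synchronizer,
# needed only when a chunk spans several batch members.
import threading

n_batch = 8
out = [None] * n_batch     # stands in for a Zarr array's batch dimension
lock = threading.Lock()    # stands in for zarr.sync.ThreadSynchronizer

def run_simulation(batch_index):
    result = batch_index ** 2          # placeholder for an actual model run
    with lock:                         # only required when chunks span batches
        out[batch_index] = result

threads = [threading.Thread(target=run_simulation, args=(i,)) for i in range(n_batch)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(out)  # [0, 1, 4, 9, 16, 25, 36, 49]
```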
38 changes: 16 additions & 22 deletions xsimlab/drivers.py
@@ -291,12 +291,6 @@ def _get_input_vars(self, dataset):

return input_vars

def _maybe_validate_inputs(self, model, input_vars):
p_names = set([v[0] for v in input_vars])

if self._validate_option is not None:
model.validate(p_names)

def get_results(self):
"""Get simulation results as a xarray.Dataset loaded from
the zarr store.
@@ -368,6 +362,14 @@ def _run_one_model(self, dataset, model, batch=-1, parallel=False):
ds_init, ds_gby_steps = _generate_runtime_datasets(dataset)

validate_all = self._validate_option is ValidateOption.ALL
validate_inputs = validate_all or self._validate_option is ValidateOption.INPUTS

execute_kwargs = {
"hooks": self.hooks,
"validate": validate_all,
"parallel": parallel,
"scheduler": self.scheduler,
}

rt_context = RuntimeContext(
batch_size=self.batch_size,
@@ -377,17 +379,9 @@ def _run_one_model(self, dataset, model, batch=-1, parallel=False):
sim_end=ds_init["_sim_end"].values,
)

in_vars = self._get_input_vars(ds_init)
model.set_inputs(in_vars, ignore_static=True)
self._maybe_validate_inputs(model, in_vars)

execute_kwargs = {
"hooks": self.hooks,
"validate": validate_all,
"parallel": parallel,
"scheduler": self.scheduler,
}

model.update_state(
self._get_input_vars(ds_init), validate=validate_inputs, ignore_static=True
)
model.execute("initialize", rt_context, **execute_kwargs)

for step, (_, ds_step) in enumerate(ds_gby_steps):
@@ -398,11 +392,11 @@ def _run_one_model(self, dataset, model, batch=-1, parallel=False):
step_end=ds_step["_clock_end"].values,
step_delta=ds_step["_clock_diff"].values,
)

in_vars = self._get_input_vars(ds_step)
model.set_inputs(in_vars, ignore_static=False)
self._maybe_validate_inputs(model, in_vars)

model.update_state(
self._get_input_vars(ds_step),
validate=validate_inputs,
ignore_static=False,
)
model.execute("run_step", rt_context, **execute_kwargs)

self.store.write_output_vars(batch, step, model=model)
30 changes: 19 additions & 11 deletions xsimlab/model.py
@@ -615,29 +615,34 @@ def state(self):
"""
return self._state

def set_inputs(self, input_data, ignore_static=False, ignore_invalid_keys=True):
"""Set or update input variable values in the model's state.
def update_state(
self, input_vars, validate=True, ignore_static=False, ignore_invalid_keys=True
):
"""Update the model's state (only input variables) with new values.
Prior to updating the model's state, first convert the values for model
variables that have a converter; otherwise copy the values.
Parameters
----------
input_data : dict_like
A mapping where keys are in the form of a
``('process_name', 'var_name')`` tuple and values are
input_vars : dict_like
A mapping where keys are in the form of
``('process_name', 'var_name')`` tuples and values are
the input values to set in the model state.
validate : bool, optional
If True (default), run the variable validators after setting the
new values.
ignore_static : bool, optional
If True, sets the values even for static variables. Otherwise
(default), raises a ``ValueError`` in order to prevent updating
values of static variables.
ignore_invalid_keys : bool, optional
If True (default), ignores keys in ``input_data`` that do not
If True (default), ignores keys in ``input_vars`` that do not
correspond to input variables in the model. Otherwise, raises
a ``KeyError``.
"""
for key, value in input_data.items():
for key, value in input_vars.items():

if key not in self.input_vars:
if ignore_invalid_keys:
@@ -655,12 +660,15 @@ def set_inputs(self, input_data, ignore_static=False, ignore_invalid_keys=True):
else:
self._state[key] = copy.copy(value)

def cache_state(self, var_key):
"""Explicitly cache the current value in state for a given model
variable.
if validate:
p_names = set([pn for pn, _ in input_vars if pn in self._processes])
self.validate(p_names)
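Stripped of the Model internals, the contract of the new ``update_state`` shown above (convert or copy each value, guard static variables, skip or reject unknown keys, then validate the touched processes) can be sketched standalone. Every name below (``meta``, ``validators``, the variable keys) is a simplified stand-in, not actual xsimlab internals:

```python
# Standalone sketch of the update_state semantics: convert or copy values,
# guard static variables, handle invalid keys, then run per-process
# validators. The signature and data structures are invented for illustration.
import copy

def update_state(state, input_vars, *, meta, validators,
                 validate=True, ignore_static=False, ignore_invalid_keys=True):
    for key, value in input_vars.items():
        if key not in meta:
            if ignore_invalid_keys:
                continue
            raise KeyError(f"{key} is not a valid input variable")
        var = meta[key]
        if var.get("static") and not ignore_static:
            raise ValueError(f"cannot update static variable {key}")
        converter = var.get("converter")
        # convert when a converter is declared, otherwise shallow-copy
        state[key] = converter(value) if converter else copy.copy(value)
    if validate:
        # run validators only for the processes that were touched
        for p in {pn for pn, _ in input_vars if pn in validators}:
            validators[p](state)

meta = {
    ("roll", "shift"): {"converter": int},
    ("grid", "length"): {"static": True},
}
validators = {"roll": lambda state: None}

state = {}
update_state(state, {("roll", "shift"): 2.5}, meta=meta, validators=validators)
print(state[("roll", "shift")])  # 2
```

Folding validation into the same call is the point of the refactor: the driver no longer needs a separate ``_maybe_validate_inputs`` step.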

def update_cache(self, var_key):
"""Update the model's cache for a given model variable.
This is generally not needed, except for on demand variables
where this may optimize multiple accesses to the variable value between
where this might optimize multiple accesses to the variable value between
two simulation stages.
No copy is performed.
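For on demand variables, the role of ``update_cache`` described in the docstring above is essentially memoization between stages; a toy illustration (all variable names invented):

```python
# Toy memoization sketch of what update_cache does for an on demand variable:
# compute the value once, then serve repeated reads (e.g. while writing
# output variables) from the cache. Variable names are invented.
compute_count = 0

def compute_u_diff(state):
    # placeholder for an on demand variable's compute method
    global compute_count
    compute_count += 1
    return state[("add", "offset")]

state = {("add", "offset"): 1}
var_cache = {}

# update_cache: evaluate and store the current value once (no copy is made)
var_cache[("add", "u_diff")] = {"value": compute_u_diff(state)}

# subsequent reads between two simulation stages hit the cache
values = [var_cache[("add", "u_diff")]["value"] for _ in range(3)]
print(compute_count)  # 1
```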
4 changes: 2 additions & 2 deletions xsimlab/stores.py
@@ -283,7 +283,7 @@ def write_output_vars(self, batch: int, step: int, model: Optional[Model] = None
clock_inc = self.clock_incs[clock][batch]

for vk in var_keys:
model.cache_state(vk)
model.update_cache(vk)

if clock_inc == 0:
for vk in var_keys:
@@ -322,7 +322,7 @@ def write_index_vars(self, model: Optional[Model] = None):

for var_key in model.index_vars:
_, vname = var_key
model.cache_state(var_key)
model.update_cache(var_key)

self._create_zarr_dataset(model, var_key, name=vname)
self.zgroup[vname][:] = model._var_cache[var_key]["value"]
22 changes: 15 additions & 7 deletions xsimlab/tests/test_model.py
@@ -217,7 +217,7 @@ def test_input_vars_dict(self, model):
)
assert "n_points" in model.input_vars_dict["init_profile"]

def test_set_inputs(self, model):
def test_update_state(self, model):
arr = np.array([1, 2, 3, 4])

input_vars = {
@@ -226,7 +226,7 @@ def test_set_inputs(self, model):
("not-a-model", "input"): 0,
}

model.set_inputs(input_vars, ignore_static=True, ignore_invalid_keys=True)
model.update_state(input_vars, ignore_static=True, ignore_invalid_keys=True)

# test converted value
assert model.state[("init_profile", "n_points")] == 10
@@ -239,22 +239,30 @@ def test_set_inputs(self, model):
# test invalid key ignored
assert ("not-a-model", "input") not in model.state

# test validate
with pytest.raises(TypeError, match=r".*'int'.*"):
model.update_state({("roll", "shift"): 2.5})

# test errors
with pytest.raises(ValueError, match=r".* static variable .*"):
model.set_inputs(input_vars, ignore_static=False, ignore_invalid_keys=True)
model.update_state(
input_vars, ignore_static=False, ignore_invalid_keys=True
)

with pytest.raises(KeyError, match=r".* not a valid input variable .*"):
model.set_inputs(input_vars, ignore_static=True, ignore_invalid_keys=False)
model.update_state(
input_vars, ignore_static=True, ignore_invalid_keys=False
)

def test_cache_state(self, model):
def test_update_cache(self, model):
model.state[("init_profile", "n_points")] = 10
model.cache_state(("init_profile", "n_points"))
model.update_cache(("init_profile", "n_points"))

assert model._var_cache[("init_profile", "n_points")]["value"] == 10

# test on demand variables
model.state[("add", "offset")] = 1
model.cache_state(("add", "u_diff"))
model.update_cache(("add", "u_diff"))

assert model._var_cache[("add", "u_diff")]["value"] == 1

