"Stateless" process executors (#127)
* "stateless" process execution

Process executors may receive a state dictionary as input and return a
state dictionary (inout and out variables only), in order to build
stateless dask graphs.

* weird behavior of distributed scheduler

* update tests

* update doc and docstrings

* update release notes

* fix doc build

* more comments and docstrings

* black

* rephrase
benbovy committed Apr 12, 2020
1 parent f0e1e1a commit 763b31d
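
The "stateless" idea described in the commit message can be sketched without Dask or xarray-simlab. This is a minimal illustration, not the code from this commit: `run_stateless`, `grow` and `report` are hypothetical names, and plain functions stand in for process executors. Each task works on a private copy of the state and returns only what it wrote, so nothing relies on a dictionary shared between workers.

```python
def run_stateless(process, model_state, upstream_outputs):
    """Run one process on a private copy of the state; return only its outputs."""
    state = dict(model_state)      # copy, so the shared model state is never mutated
    for out in upstream_outputs:   # apply the outputs of the upstream processes
        state.update(out)
    return process(state)


# Hypothetical processes: each one reads what it needs and returns its outputs only.
def grow(state):
    return {"height": state["height"] + state["rate"]}


def report(state):
    return {"label": f"h={state['height']}"}


model_state = {"height": 1.0, "rate": 0.5}
grow_out = run_stateless(grow, model_state, [])
report_out = run_stateless(report, model_state, [grow_out])
```

Because every task takes its inputs as arguments and hands back its outputs as a return value, a scheduler is free to run it in another thread, process or machine.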
Showing 9 changed files with 232 additions and 102 deletions.
1 change: 1 addition & 0 deletions ci/requirements/doc.yml
@@ -12,6 +12,7 @@ dependencies:
- pip:
- attrs==19.2.0
- dask==2.11.0
- distributed==2.11.0
- ipython==7.8.0
- matplotlib==3.0.2
- nbconvert==5.6.0
37 changes: 22 additions & 15 deletions doc/run_parallel.rst
@@ -17,7 +17,7 @@ library. There are two parallel modes:

Dask is a versatile library that provides many ways of executing tasks in
parallel (i.e., threads vs. processes, single machine vs. distributed
environments). xarray-simlab lets you choose which alternative best suits
infrastructure). xarray-simlab lets you choose which alternative best suits
your needs. Beware, however, that not all alternatives are optimal or
supported depending on your case. More details below.

@@ -32,7 +32,7 @@ This mode runs the processes in a model in parallel.

A :class:`~xsimlab.Model` object can be viewed as a Directed Acyclic Graph (DAG)
built from a collection of processes (i.e., process-decorated classes) as nodes
and their inter-dependencies as directed edges. At each simulation stage, a task
and their inter-dependencies as oriented edges. At each simulation stage, a task
graph is built from this DAG, which is then executed by one of the schedulers
available in Dask.

@@ -43,19 +43,25 @@ To activate this parallel mode, simply set ``parallel=True`` when calling
>>> in_ds.xsimlab.run(model=my_model, parallel=True)
The default Dask scheduler used here is ``"threads"`` (this is the one used by
``dask.delayed``). Other schedulers may be selected via the ``scheduler``
argument of :func:`~xarray.Dataset.xsimlab.run`. Dask also provides other ways to
select a scheduler, see `here
<https://docs.dask.org/en/latest/setup/single-machine.html>`_.
The default Dask scheduler used here is ``"threads"``. The code in the
process-decorated classes must thus be thread-safe. It should also release
CPython's Global Interpreter Lock (GIL) as much as possible in order to see a
gain in performance. For example, most Numpy functions release the GIL.

Other schedulers may be selected via the ``scheduler`` argument of
:func:`~xarray.Dataset.xsimlab.run`.

.. code:: python
>>> in_ds.xsimlab.run(model=my_model, parallel=True, scheduler="processes")
Multi-processes schedulers are not supported for this mode since simulation
active data, shared between all model components, is stored using a simple
Python dictionary.
Dask also provides other ways to select a scheduler, see `here
<https://docs.dask.org/en/latest/setup/single-machine.html>`_.

The code in the process-decorated classes must be thread-safe and should release
CPython's Global Interpreter Lock (GIL) as much as possible in order to see
a gain in performance. For example, most Numpy functions release the GIL.
Note, however, that multi-processes or distributed schedulers are not well
supported and may have very poor performance for this mode, depending on how
much simulation active data needs to be shared between the model components. See
:meth:`xsimlab.Model.execute` for more information.

The gain in performance compared to sequential execution of the model processes
will also depend on how the DAG is structured, i.e., how many processes can be
@@ -87,8 +93,9 @@ while calling :func:`xarray.Dataset.xsimlab.run` (see Section
>>> in_ds.xsimlab.run(model=my_model, batch_dim="batch", parallel=True, store="output.zarr")
As opposed to single-model parallelism, both multi-threads and multi-processes
Dask schedulers are supported for this embarrassingly parallel problem.
Both multi-threads and multi-processes Dask schedulers are well supported for
this embarrassingly parallel problem. Like for single-model parallelism, the
default scheduler used here is ``"threads"``.

If you use a multi-threads scheduler, the same precautions apply regarding
thread-safety and CPython's GIL.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
@@ -14,6 +14,8 @@ Enhancements
:func:`xarray.Dataset.xsimlab.update_vars` now accepts array-like values
with no explicit dimension label(s), in this case those labels are inferred
from model variables' metadata (:issue:`126`).
- Single-model parallelism now supports Dask's multi-processes or distributed
schedulers, although this is still limited and rarely optimal (:issue:`127`).

Bug fixes
~~~~~~~~~
84 changes: 66 additions & 18 deletions xsimlab/model.py
@@ -1,8 +1,10 @@
from collections import OrderedDict, defaultdict
import copy
import time

import attr
import dask
from dask.distributed import Client

from .variable import VarIntent, VarType
from .process import (
@@ -721,31 +723,63 @@ def _call_hooks(self, hooks, runtime_context, stage, level, trigger):
for h in event_hooks:
h(self, Frozen(runtime_context), Frozen(self.state))

def _execute_process(self, p_obj, stage, runtime_context, hooks, validate):
def _execute_process(
self, p_obj, stage, runtime_context, hooks, validate, state=None
):
executor = p_obj.__xsimlab_executor__
p_name = p_obj.__xsimlab_name__

self._call_hooks(hooks, runtime_context, stage, "process", "pre")
executor.execute(p_obj, stage, runtime_context)
out_state = executor.execute(p_obj, stage, runtime_context, state=state)
self._call_hooks(hooks, runtime_context, stage, "process", "post")

if validate:
self.validate(self._processes_to_validate[p_name])

def _build_dask_graph(self, extra_args):
def exec_process(p_obj, deps):
self._execute_process(p_obj, *extra_args)
return p_name, out_state

dsk = {
p_name: (exec_process, self._processes[p_name], p_deps)
for p_name, p_deps in self._dep_processes.items()
}
def _build_dask_graph(self, execute_args):
"""Build a custom, 'stateless' graph of tasks (process execution) that will
be passed to a Dask scheduler.
"""

# add a dummy node so that we properly call the get func of dask scheduler
dsk["_end"] = (lambda deps: None, list(self._processes))
def exec_process(p_obj, model_state, out_states):
# update model state with output state from all dependent processes
state = {}
state.update(model_state)
for _, s in out_states:
state.update(s)

return self._execute_process(p_obj, *execute_args, state=state)

dsk = {}
for p_name, p_deps in self._dep_processes.items():
dsk[p_name] = (exec_process, self._processes[p_name], self._state, p_deps)

# add a node to gather output state from all executed processes
dsk["_gather"] = (lambda out_states: dict(out_states), list(self._processes))

return dsk

def _merge_and_update_state(self, out_states):
"""Collect, merge together and update model state from the output
states returned by all executed processes (dask graph).
"""
new_state = {}

# process order matters!
for p_name in self._processes:
new_state.update(out_states[p_name])

self._state.update(new_state)

# need to re-assign the updated state to all processes
# for access between simulation stages (e.g., save snapshots)
for p_obj in self._processes.values():
p_obj.__xsimlab_state__ = self._state

def execute(
self,
stage,
@@ -793,11 +827,18 @@ def execute(
in the process classes must be thread-safe. Also, it should release
the Python Global Interpreter Lock (GIL) as much as possible in order
to see a gain in performance.
- Multi-process or distributed schedulers are not supported, as
currently the model state (shared between the process classes)
is stored using a simple Python dictionary.
- Multi-process or distributed schedulers may have very poor performance,
especially when a lot of data (model state) is shared between the model
processes. The way xarray-simlab scatters/gathers this data between the
scheduler and the workers is not optimized at all. Additionally, those
schedulers may not work well with the given ``hooks`` and/or when the
processes runtime methods rely on instance attributes that are not
explicitly declared as model variables.
"""
# TODO: issue warning if validate is True and "processes" or distributed scheduler
# is used (not supported)

if hooks is None:
hooks = {}

@@ -806,17 +847,24 @@ def execute(
dsk_get = dask.threaded.get

stage = SimulationStage(stage)
extra_args = (stage, runtime_context, hooks, validate)
execute_args = (stage, runtime_context, hooks, validate)

self._call_hooks(hooks, runtime_context, stage, "model", "pre")

if parallel:
dsk = self._build_dask_graph(extra_args)
dsk_get(dsk, "_end", scheduler=scheduler)
dsk = self._build_dask_graph(execute_args)
out_states = dsk_get(dsk, "_gather", scheduler=scheduler)

# TODO: without this -> flaky tests (don't know why)
# state is not well updated -> error when writing output vars in store
if isinstance(scheduler, Client):
time.sleep(0.001)

self._merge_and_update_state(out_states)

else:
for p_name, p_obj in self._processes.items():
self._execute_process(p_obj, *extra_args)
self._execute_process(p_obj, *execute_args)

self._call_hooks(hooks, runtime_context, stage, "model", "post")
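
The "process order matters!" comment in `_merge_and_update_state` can be illustrated with plain dictionaries. This is a simplified sketch under assumptions: `merge_out_states` is a hypothetical helper, and the process and variable names are made up. When two processes write the same key (for example an `inout` variable), merging in model process order means the later process wins, whatever order the scheduler returned the results in.

```python
def merge_out_states(process_names, out_states):
    """Merge per-process output states, following the model's process order."""
    merged = {}
    for name in process_names:
        merged.update(out_states[name])  # later processes overwrite earlier ones
    return merged


# Results may come back from the scheduler in any order (here "smooth" is first).
out_states = {
    "smooth": {"elevation": 2.0},
    "init": {"elevation": 1.0, "slope": 0.1},
}

# Hypothetical model order: "init" runs before "smooth".
merged = merge_out_states(["init", "smooth"], out_states)
reversed_merge = merge_out_states(["smooth", "init"], out_states)
```

Iterating over the model's own ordering rather than over the gathered results keeps the final state deterministic.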

106 changes: 66 additions & 40 deletions xsimlab/process.py
@@ -276,10 +276,13 @@ def __init__(self, meth, args=None):

self.args = tuple(args)

def execute(self, obj, runtime_context):
def execute(self, obj, runtime_context, state=None):
if state is not None:
obj.__xsimlab_state__ = state

args = [runtime_context[k] for k in self.args]

return self.meth(obj, *args)
self.meth(obj, *args)


def runtime(meth=None, args=None):
@@ -330,61 +333,84 @@ class SimulationStage(Enum):
FINALIZE = "finalize"


class _ProcessExecutor:
"""Used to execute a process during simulation runtime."""
def _create_runtime_executors(cls):
runtime_executors = OrderedDict()

def __init__(self, cls):
self.cls = cls
for stage in SimulationStage:
if not has_method(cls, stage.value):
continue

meth = getattr(cls, stage.value)
executor = getattr(meth, "__xsimlab_executor__", None)

if executor is None:
nparams = len(inspect.signature(meth).parameters)

if stage == SimulationStage.RUN_STEP and nparams == 2:
# TODO: remove (depreciated)
warnings.warn(
"`run_step(self, dt)` accepting by default "
"one positional argument is depreciated and "
"will be removed in a future version of "
"xarray-simlab. Use the `@runtime` "
"decorator.",
FutureWarning,
)
args = ["step_delta"]

self.runtime_methods = OrderedDict()
elif nparams > 1:
raise TypeError(
"Process runtime methods with positional "
"parameters should be decorated with "
"`@runtime`"
)

for stage in SimulationStage:
if not has_method(self.cls, stage.value):
continue
else:
args = None

meth = getattr(self.cls, stage.value)
executor = getattr(meth, "__xsimlab_executor__", None)
executor = _RuntimeMethodExecutor(meth, args=args)

if executor is None:
nparams = len(inspect.signature(meth).parameters)
runtime_executors[stage] = executor

if stage == SimulationStage.RUN_STEP and nparams == 2:
# TODO: remove (depreciated)
warnings.warn(
"`run_step(self, dt)` accepting by default "
"one positional argument is depreciated and "
"will be removed in a future version of "
"xarray-simlab. Use the `@runtime` "
"decorator.",
FutureWarning,
)
args = ["step_delta"]
return runtime_executors

elif nparams > 1:
raise TypeError(
"Process runtime methods with positional "
"parameters should be decorated with "
"`@runtime`"
)

else:
args = None
def _get_out_variables(cls):
def filter_out(var):
var_type = var.metadata["var_type"]
var_intent = var.metadata["intent"]

if var_type != VarType.ON_DEMAND and var_intent != VarIntent.IN:
return True
else:
return False

return filter_variables(cls, func=filter_out)

executor = _RuntimeMethodExecutor(meth, args=args)

self.runtime_methods[stage] = executor
class _ProcessExecutor:
"""Used to execute a process during simulation runtime."""

def __init__(self, cls):
self.cls = cls
self.runtime_executors = _create_runtime_executors(cls)
self.out_vars = _get_out_variables(cls)

@property
def stages(self):
return [k.value for k in self.runtime_methods]
return [k.value for k in self.runtime_executors]

def execute(self, obj, stage, runtime_context):
executor = self.runtime_methods.get(stage)
def execute(self, obj, stage, runtime_context, state=None):
executor = self.runtime_executors.get(stage)

if executor is None:
return None
return {}
else:
return executor.execute(obj, runtime_context)
executor.execute(obj, runtime_context, state=state)

skeys = [obj.__xsimlab_state_keys__[k] for k in self.out_vars]
sobj = obj.__xsimlab_state__
return {k: sobj[k] for k in skeys if k in sobj}


def _process_cls_init(obj):
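
The way `_ProcessExecutor.execute` builds its return value (keep only the non-input variables that are present in the state) can be sketched as follows. This is a simplified stand-in, not the library code: `Intent` and `select_out_state` are hypothetical names, the `(process, variable)` tuple keys are an assumption, and the on-demand variable check done by `_get_out_variables` is omitted.

```python
from enum import Enum


class Intent(Enum):
    IN = "in"
    OUT = "out"
    INOUT = "inout"


def select_out_state(state, state_keys, intents):
    """Return the subset of `state` written by a process (out and inout variables)."""
    out_names = [name for name, intent in intents.items() if intent is not Intent.IN]
    keys = [state_keys[name] for name in out_names]
    # Skip keys that are not in the state yet (e.g. an output not computed at this stage).
    return {k: state[k] for k in keys if k in state}


# Hypothetical process "grow" with one variable of each intent.
intents = {"rate": Intent.IN, "height": Intent.INOUT, "label": Intent.OUT}
state_keys = {name: ("grow", name) for name in intents}
state = {("grow", "rate"): 0.5, ("grow", "height"): 1.5}

out_state = select_out_state(state, state_keys, intents)
```

Returning only the written variables keeps the data sent back to the scheduler smaller than the full model state.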
2 changes: 1 addition & 1 deletion xsimlab/tests/__init__.py
@@ -22,4 +22,4 @@ def _importorskip(modname):
use_dask_schedulers = ["single-threaded"]
else:
# Still useful to test threads/processes (pickle issues) locally
use_dask_schedulers = ["threads", "processes", "distributed"]
use_dask_schedulers = ["threads", "processes", "distributed", "distributed-threads"]
