Skip to content

Commit

Permalink
Validate values given as input or set in foreign processes (#74)
Browse files Browse the repository at this point in the history
* validate model inputs (+ update tests)

* runtime validation (foreign variables)

* update doc and release notes

* docstrings tweaks
  • Loading branch information
benbovy committed Dec 13, 2019
1 parent 39519a9 commit 5fbf5d0
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 25 deletions.
5 changes: 5 additions & 0 deletions doc/framework.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,17 @@ as:
needs (``intent='in'``), updates (``intent='inout'``) or computes
(``intent='out'``) a value for that variable.

It is also possible to set a default value as well as value validator(s).
See `attrs' validators`_ for more details.

.. note::

xarray-simlab does not distinguish between model parameters, input
and output variables. All can be declared using
:func:`~xsimlab.variable`.

.. _`attrs' validators`: https://www.attrs.org/en/stable/examples.html#validators

Foreign variables
~~~~~~~~~~~~~~~~~

Expand Down
4 changes: 4 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ Enhancements
has also a new ``fill_default`` parameter.
- Added static variables, i.e., variables that don't accept time-varying input
values (:issue:`73`).
- Added support for the validation of variable values (given as inputs and/or
set through foreign variables), reusing :func:`attr.validate` (:issue:`74`).
Validation is optional and is controlled by the parameter ``validate`` added
to :func:`xarray.Dataset.xsimlab.run`.

Bug fixes
~~~~~~~~~
Expand Down
62 changes: 51 additions & 11 deletions xsimlab/drivers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from collections.abc import Mapping
import copy
from enum import Enum

import attr
import numpy as np
import xarray as xr

from .utils import variables_dict


class ValidateOption(Enum):
NOTHING = 'nothing'
INPUTS = 'inputs'
ALL = 'all'


class RuntimeContext(Mapping):
"""A mapping providing runtime information at the current time step."""

Expand Down Expand Up @@ -52,6 +60,7 @@ class BaseSimulationDriver:
simulation output store.
"""

def __init__(self, model, store, output_store):
self.model = model
self.store = store
Expand Down Expand Up @@ -125,6 +134,12 @@ def update_output_store(self, output_var_keys):

self.output_store.append(key, value)

def validate(self, p_names):
"""Run validators for all processes given in `p_names`."""

for pn in p_names:
attr.validate(self.model[pn])

def run_model(self):
"""Main function of the driver used to run a simulation (must be
implemented in sub-classes).
Expand Down Expand Up @@ -153,14 +168,17 @@ class XarraySimulationDriver(BaseSimulationDriver):
"""Simulation driver using xarray.Dataset objects as I/O.
- Perform some sanity checks on the content of the given input Dataset.
- Set model inputs from data variables or coordinates in the input Dataset.
- Set (and maybe validate) model inputs from data variables or coordinates
in the input Dataset.
- Save model outputs for given model variables, defined in specific
attributes of the input Dataset, on time frequencies given by clocks
defined as coordinates in the input Dataset.
- Get simulation results as a new xarray.Dataset object.
"""
def __init__(self, dataset, model, store, output_store):

def __init__(self, dataset, model, store, output_store,
validate=ValidateOption.INPUTS):
self.dataset = dataset
self.model = model

Expand All @@ -176,6 +194,8 @@ def __init__(self, dataset, model, store, output_store):
self.output_vars = dataset.xsimlab.output_vars
self.output_save_steps = self._get_output_save_steps()

self._validate_option = ValidateOption(validate)

def _check_missing_model_inputs(self):
"""Check if all model inputs have their corresponding variables
in the input Dataset.
Expand Down Expand Up @@ -307,15 +327,21 @@ def _get_runtime_datasets(self):
'_clock_diff': mclock_coord.diff(mclock_dim, label='lower')
}

ds_all_steps = (self.dataset.drop(list(ds_init.data_vars.keys()),
errors='ignore')
.isel({mclock_dim: slice(0, -1)})
.assign(step_data_vars))
ds_all_steps = (self.dataset
.drop(list(ds_init.data_vars.keys()), errors='ignore')
.isel({mclock_dim: slice(0, -1)})
.assign(step_data_vars))

ds_gby_steps = ds_all_steps.groupby(mclock_dim)

return ds_init, ds_gby_steps

def _maybe_validate_inputs(self, input_vars):
p_names = set([v[0] for v in input_vars])

if self._validate_option != ValidateOption.NOTHING:
self.validate(p_names)

def run_model(self):
"""Run the model and return a new Dataset with all the simulation
inputs and outputs.
Expand All @@ -328,14 +354,20 @@ def run_model(self):
"""
ds_init, ds_gby_steps = self._get_runtime_datasets()

validate_all = self._validate_option == ValidateOption.ALL

runtime_context = RuntimeContext(
sim_start=ds_init['_sim_start'].values,
sim_end=ds_init['_sim_end'].values
)

self.initialize_store(self._get_input_vars(ds_init))
in_vars = self._get_input_vars(ds_init)
self.initialize_store(in_vars)
self._maybe_validate_inputs(in_vars)

self.model.execute('initialize', runtime_context)
self.model.execute('initialize',
runtime_context,
validate=validate_all)

for step, (_, ds_step) in enumerate(ds_gby_steps):

Expand All @@ -344,11 +376,19 @@ def run_model(self):
step_end=ds_step['_clock_end'].values,
step_delta=ds_step['_clock_diff'].values)

self.update_store(self._get_input_vars(ds_step))
in_vars = self._get_input_vars(ds_step)
self.update_store(in_vars)
self._maybe_validate_inputs(in_vars)

self.model.execute('run_step',
runtime_context,
validate=validate_all)

self.model.execute('run_step', runtime_context)
self._maybe_save_output_vars(step)
self.model.execute('finalize_step', runtime_context)

self.model.execute('finalize_step',
runtime_context,
validate=validate_all)

self._maybe_save_output_vars(-1)
self.model.execute('finalize', runtime_context)
Expand Down
39 changes: 37 additions & 2 deletions xsimlab/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections import OrderedDict, defaultdict

import attr

from .variable import VarIntent, VarType
from .process import (filter_variables, get_process_cls,
get_target_variable, SimulationStage)
Expand Down Expand Up @@ -248,6 +250,28 @@ def filter_out(var):

return self._input_vars

def get_processes_to_validate(self):
"""Return a dictionary where keys are each process of the model and
values are lists of the names of other processes for which to trigger
validators right after its execution.
Useful for triggering validators of variables defined in other
processes when new values are set through foreign variables.
"""
processes_to_validate = {k: set() for k in self._processes_obj}

for p_name, p_obj in self._processes_obj.items():
out_foreign_vars = filter_variables(p_obj,
var_type=VarType.FOREIGN,
intent=VarIntent.OUT)

for var in out_foreign_vars.values():
pn, _ = p_obj.__xsimlab_store_keys__[var.name]
processes_to_validate[p_name].add(pn)

return {k: list(v) for k, v in processes_to_validate.items()}

def get_process_dependencies(self):
"""Return a dictionary where keys are each process of the model and
values are lists of the names of dependent processes (or empty
Expand Down Expand Up @@ -408,6 +432,8 @@ def __init__(self, processes):
self._input_vars = builder.get_input_variables()
self._input_vars_dict = None

self._processes_to_validate = builder.get_processes_to_validate()

self._dep_processes = builder.get_process_dependencies()
self._processes = builder.get_sorted_processes()

Expand Down Expand Up @@ -503,7 +529,7 @@ def visualize(self, show_only_variable=None, show_inputs=False,
show_inputs=show_inputs,
show_variables=show_variables)

def execute(self, stage, runtime_context):
def execute(self, stage, runtime_context, validate=False):
"""Run one stage of a simulation.
This shouldn't be called directly, except for debugging purpose.
Expand All @@ -515,12 +541,21 @@ def execute(self, stage, runtime_context):
runtime_context : dict
Dictionary containing runtime variables (e.g., time step
duration, current step).
validate : bool, optional
If True, run the variable validators in the corresponding
processes after a process (maybe) sets values through its foreign
variables (default: False). This is useful for debugging but
it may significantly impact performance.
"""
for p_obj in self._processes.values():
for p_name, p_obj in self._processes.items():
executor = p_obj.__xsimlab_executor__
executor.execute(p_obj, SimulationStage(stage), runtime_context)

if validate:
for pn in self._processes_to_validate[p_name]:
attr.validate(self._processes[pn])

def clone(self):
"""Clone the Model, i.e., create a new Model instance with the same
process classes but different instances.
Expand Down
2 changes: 2 additions & 0 deletions xsimlab/tests/fixture_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import OrderedDict
from textwrap import dedent

import attr
import numpy as np
import xarray as xr
import pytest
Expand Down Expand Up @@ -46,6 +47,7 @@ def initialize(self):
@xs.process
class Roll:
shift = xs.variable(default=2,
validator=attr.validators.instance_of(int),
description=('shift profile by a nb. of points'),
attrs={'units': 'unitless'})
u = xs.foreign(Profile, 'u')
Expand Down
6 changes: 6 additions & 0 deletions xsimlab/tests/test_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def test_update_output_store(self, base_driver):
base_driver.output_store[('profile', 'u')],
expected)

def test_validate(self, base_driver):
base_driver.store[('roll', 'shift')] = 2.5

with pytest.raises(TypeError, match=r".*'int'.*"):
base_driver.validate(['roll'])

def test_run_model(self, base_driver):
with pytest.raises(NotImplementedError):
base_driver.run_model()
Expand Down
41 changes: 40 additions & 1 deletion xsimlab/tests/test_xr_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import xarray as xr
import numpy as np

import xsimlab as xs
from xsimlab import xr_accessor, create_setup
from xsimlab.xr_accessor import (as_variable_key,
_flatten_inputs, _flatten_outputs,
_maybe_get_model_from_context)

from .fixture_model import Roll


def test_filter_accessor():
ds = xr.Dataset(data_vars={'var1': ('x', [1, 2]), 'var2': ('y', [3, 4])},
Expand Down Expand Up @@ -279,7 +282,7 @@ def test_output_vars(self, model):
'out': [('roll', 'u_diff'), ('add', 'u_diff')]}
assert ds.xsimlab.output_vars == expected

def test_run(self, model, in_dataset):
def test_run_safe_mode(self, model, in_dataset):
# safe mode True: ensure model is cloned
_ = in_dataset.xsimlab.run(model=model, safe_mode=True)
assert model.profile.__xsimlab_store__ is None
Expand All @@ -288,6 +291,42 @@ def test_run(self, model, in_dataset):
_ = in_dataset.xsimlab.run(model=model, safe_mode=False)
assert model.profile.u is not None

def test_run_validate(self, model, in_dataset):
in_dataset['roll__shift'] = 2.5

# no validation -> raises within np.roll()
with pytest.raises(TypeError,
match=r"slice indices must be integers.*"):
in_dataset.xsimlab.run(model=model, validate='nothing')

# input validation at initialization -> raises within attr.validate()
with pytest.raises(TypeError, match=r".*'int'.*"):
in_dataset.xsimlab.run(model=model, validate='inputs')

in_dataset['roll__shift'] = ('clock', [1, 2.5, 1, 1, 1])

# input validation at runtime -> raises within attr.validate()
with pytest.raises(TypeError, match=r".*'int'.*"):
in_dataset.xsimlab.run(model=model, validate='inputs')

@xs.process
class SetRollShift:
shift = xs.foreign(Roll, 'shift', intent='out')

def initialize(self):
self.shift = 2.5

m = model.update_processes({'set_shift': SetRollShift})

# no validation -> raises within np.roll()
with pytest.raises(TypeError,
match=r"slice indices must be integers.*"):
in_dataset.xsimlab.run(model=m, validate='inputs')

# internal validation -> raises within attr.validate()
with pytest.raises(TypeError, match=r".*'int'.*"):
in_dataset.xsimlab.run(model=m, validate='all')

def test_run_multi(self):
ds = xr.Dataset()

Expand Down
24 changes: 15 additions & 9 deletions xsimlab/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,19 @@ def variable(dims=(), intent='in', group=None, groups=None,
Single default value for the variable, ignored when ``intent='out'``
(default: NOTHING). A default value may also be set using a decorator.
validator : callable or list of callable, optional
Function that is called at simulation initialization (and possibly at
other times too) to check the value given for the variable.
Function that could be called before or during a simulation (or when
creating a new process instance) to check the value given
for the variable.
The function must accept three arguments:
- the process instance (access other variables)
- the variable object (access metadata)
- a passed value (check input).
- the process instance (useful for accessing the value of other
variables in that process)
- the variable object (useful for accessing the variable metadata)
- the value to be validated.
The function is expected to throw an exception in case of invalid
value.
If a ``list`` is passed, its items are treated as validators and must
all pass.
The function should throw an exception in case where an invalid value
is given.
If a ``list`` is passed, its items are all are treated as validators.
The validator can also be set using decorator notation.
static : bool, optional
If True, the value of the (input) variable must be set once
Expand All @@ -158,6 +159,11 @@ def variable(dims=(), intent='in', group=None, groups=None,
Dictionnary of additional metadata (e.g., standard_name,
units, math_symbol...).
See Also
--------
:func:`attr.ib`
:mod:`attr.validators`
"""
metadata = {'var_type': VarType.VARIABLE,
'dims': _as_dim_tuple(dims),
Expand Down

0 comments on commit 5fbf5d0

Please sign in to comment.