Skip to content

Commit

Permalink
Check (transpose) the dimensions of input variables (#76)
Browse files Browse the repository at this point in the history
* check/transpose input variables dims (+ fix tests)

* fix check_dims=None

* update tests

* maybe transpose input/output variable

* update doc / release notes
  • Loading branch information
benbovy committed Dec 16, 2019
1 parent 45c2f71 commit 260ed14
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 25 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ Enhancements
set through foreign variables), reusing :func:`attr.validate` (:issue:`74`).
Validation is optional and is controlled by the parameter ``validate`` added
to :func:`xarray.Dataset.xsimlab.run`.
- Check or automatically transpose the dimensions of the variables given in
input xarray Datasets to match those defined in model variables (:issue:`76`).
This is optional and controlled by the parameter ``check_dims`` added
to :func:`xarray.Dataset.xsimlab.run`.

Bug fixes
~~~~~~~~~
Expand Down
71 changes: 63 additions & 8 deletions xsimlab/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ class ValidateOption(Enum):
ALL = 'all'


class CheckDimsOption(Enum):
    """Ways of checking the dimensions of variables given as model inputs."""

    # the dimension labels of an input must match, in order, one of the
    # label sequences accepted by the model variable
    STRICT = 'strict'
    # an input may be transposed so that its labels match one of the
    # label sequences accepted by the model variable
    TRANSPOSE = 'transpose'


class RuntimeContext(Mapping):
"""A mapping providing runtime information at the current time step."""

Expand Down Expand Up @@ -178,6 +183,7 @@ class XarraySimulationDriver(BaseSimulationDriver):
"""

def __init__(self, dataset, model, store, output_store,
check_dims=CheckDimsOption.STRICT,
validate=ValidateOption.INPUTS):
self.dataset = dataset
self.model = model
Expand All @@ -194,6 +200,12 @@ def __init__(self, dataset, model, store, output_store,
self.output_vars = dataset.xsimlab.output_vars
self.output_save_steps = self._get_output_save_steps()

if check_dims is not None:
check_dims = CheckDimsOption(check_dims)
self._check_dims_option = check_dims

self._transposed_vars = {}

self._validate_option = ValidateOption(validate)

def _check_missing_model_inputs(self):
Expand Down Expand Up @@ -234,21 +246,47 @@ def _get_output_save_steps(self):

return save_steps

def _maybe_transpose(self, xr_var, p_name, var_name):
    """Check, and possibly reorder, the dimensions of an input variable.

    The dimension sequences accepted for the variable are read from
    the ``dims`` metadata of the model variable ``var_name`` declared
    in process ``p_name``.

    Parameters
    ----------
    xr_var : object
        Variable taken from the input dataset. Only its ``dims``
        attribute and its ``transpose`` method are used here.
    p_name : str
        Key of the process in ``self.model``.
    var_name : str
        Name of the variable declared in that process.

    Returns
    -------
    object
        ``xr_var`` itself, or a transposed version of it when the
        'transpose' option is active and its dimension labels match
        an accepted sequence in a different order.

    Raises
    ------
    ValueError
        If a check is active ('strict' or 'transpose') and the
        (possibly transposed) dimensions are not one of the accepted
        sequences.

    """
    var = variables_dict(self.model[p_name].__class__)[var_name]

    dims = var.metadata['dims']
    # map each unordered set of labels to its accepted ordering
    dims_by_labels = {frozenset(d): d for d in dims}
    xr_labels = frozenset(xr_var.dims)

    strict = self._check_dims_option is CheckDimsOption.STRICT
    transpose = self._check_dims_option is CheckDimsOption.TRANSPOSE

    # Reorder only when the given order is not already accepted:
    # several accepted sequences may share the same labels (the mapping
    # above keeps just one of them), so a valid input must be left as is
    # rather than be silently flipped to another permutation.
    needs_reorder = xr_var.dims not in dims and xr_labels in dims_by_labels

    if transpose and needs_reorder:
        # remember the original order so that the corresponding
        # output can be transposed back
        self._transposed_vars[(p_name, var_name)] = xr_var.dims
        xr_var = xr_var.transpose(*dims_by_labels[xr_labels])

    if (strict or transpose) and xr_var.dims not in dims:
        raise ValueError("Invalid dimension(s) for variable '{}__{}': "
                         "found {!r}, must be one of {}"
                         .format(p_name, var_name, xr_var.dims,
                                 ",".join(str(d) for d in dims)))

    return xr_var

def _get_input_vars(self, dataset):
input_vars = {}

for p_name, var_name in self.model.input_vars:
xr_var_name = p_name + '__' + var_name
xr_var = dataset.get(xr_var_name)

if xr_var is not None:
data = xr_var.data
if xr_var is None:
continue

xr_var = self._maybe_transpose(xr_var, p_name, var_name)

data = xr_var.data

if data.ndim == 0:
# convert array to scalar
data = data.item()
if data.ndim == 0:
# convert array to scalar
data = data.item()

input_vars[(p_name, var_name)] = data
input_vars[(p_name, var_name)] = data

return input_vars

Expand All @@ -264,7 +302,12 @@ def _maybe_save_output_vars(self, istep):
self.update_output_store(var_keys)

def _to_xr_variable(self, key, clock):
"""Convert an output variable to a xarray.Variable object."""
"""Convert an output variable to a xarray.Variable object.
Maybe transpose the variable to match the dimension order
of the variable given in the input dataset (if any).
"""
p_name, var_name = key
p_obj = self.model[p_name]
var = variables_dict(type(p_obj))[var_name]
Expand All @@ -274,14 +317,26 @@ def _to_xr_variable(self, key, clock):
data = data[0]

dims = _get_dims_from_variable(data, var, clock)
original_dims = self._transposed_vars.get(key)

if clock is not None:
dims = (clock,) + dims

if original_dims is not None:
original_dims = (clock,) + original_dims

attrs = var.metadata['attrs'].copy()
if var.metadata['description']:
attrs['description'] = var.metadata['description']

return xr.Variable(dims, data, attrs=attrs)
xr_var = xr.Variable(dims, data, attrs=attrs)

if original_dims is not None:
# TODO: use ellipsis for clock dim in transpose
# added in xarray 0.14.1 (too recent)
xr_var = xr_var.transpose(*original_dims)

return xr_var

def _get_output_dataset(self):
"""Return a new dataset as a copy of the input dataset updated with
Expand Down
3 changes: 2 additions & 1 deletion xsimlab/tests/fixture_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def run_step(self, dt):

@xs.process
class AddOnDemand:
offset = xs.variable(description='offset added to profile u')
offset = xs.variable(dims=[(), 'x'],
description='offset added to profile u')
u_diff = xs.on_demand(groups='diff')

@u_diff.compute
Expand Down
29 changes: 14 additions & 15 deletions xsimlab/tests/test_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,17 @@ def test_output_save_steps(self, xarray_driver):
for k in expected:
assert_array_equal(xarray_driver.output_save_steps[k], expected[k])

@pytest.mark.parametrize('var_key,is_scalar', [
(('init_profile', 'n_points'), True),
(('add', 'offset'), False)
@pytest.mark.parametrize('value,is_scalar', [
(1, True),
(('x', [1, 1, 1, 1, 1]), False)
])
def test_get_input_vars(self, in_dataset, xarray_driver,
var_key, is_scalar):
def test_get_input_vars_scalar(self, in_dataset, xarray_driver,
value, is_scalar):
in_dataset['add__offset'] = value
in_vars = xarray_driver._get_input_vars(in_dataset)

actual = in_vars[var_key]
expected = in_dataset['__'.join(var_key)].data
actual = in_vars[('add', 'offset')]
expected = in_dataset['add__offset'].data

if is_scalar:
assert actual == expected
Expand All @@ -143,18 +144,16 @@ def test_run_model(self, in_dataset, out_dataset, xarray_driver):

def test_runtime_context(self, in_dataset, model):
@xs.process
class BadProcess:
class P:

@xs.runtime(args='bad')
def run_step(self, bad):
@xs.runtime(args='not_a_runtime_arg')
def run_step(self, arg):
pass

bad_model = model.update_processes({'bad': BadProcess})
m = model.update_processes({'p': P})

driver = XarraySimulationDriver(in_dataset, bad_model,
driver = XarraySimulationDriver(in_dataset, m,
{}, InMemoryOutputStore())

with pytest.raises(KeyError) as excinfo:
with pytest.raises(KeyError, match="'not_a_runtime_arg'"):
driver.run_model()

assert str(excinfo.value) == "'bad'"
43 changes: 43 additions & 0 deletions xsimlab/tests/test_xr_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,49 @@ def test_run_safe_mode(self, model, in_dataset):
_ = in_dataset.xsimlab.run(model=model, safe_mode=False)
assert model.profile.u is not None

def test_run_check_dims(self):
    """Exercise the three ``check_dims`` modes of ``xsimlab.run``."""
    @xs.process
    class P:
        var = xs.variable(dims=['x', ('x', 'y')])

    model = xs.Model({'p': P})
    values = np.array([[1, 2], [3, 4]])

    # the input is given as ('y', 'x'), i.e., in reverse order
    # compared to the ('x', 'y') sequence accepted by the variable
    setup_ds = xs.create_setup(
        model=model,
        clocks={'clock': [1, 2]},
        input_vars={'p__var': (('y', 'x'), values)},
        output_vars={None: ['p__var']},
    )

    # no check: values pass through untouched
    unchecked_ds = setup_ds.xsimlab.run(model=model, check_dims=None)
    np.testing.assert_array_equal(unchecked_ds.p__var.values, values)

    # strict check: the reversed order is rejected
    with pytest.raises(ValueError, match=r"Invalid dimension.*"):
        setup_ds.xsimlab.run(model=model, check_dims='strict')

    # transpose: the model gets the transposed values while the output
    # keeps the dimension order of the input
    transposed_ds = setup_ds.xsimlab.run(
        model=model, check_dims='transpose', safe_mode=False
    )
    np.testing.assert_array_equal(transposed_ds.p__var.values, values)
    np.testing.assert_array_equal(model.p.var, values.T)

    # same with the output saved along the clock dimension
    clock_setup_ds = setup_ds.xsimlab.update_vars(
        model=model,
        output_vars={'clock': ['p__var']},
    )
    # TODO: fix update output vars time-independent -> dependent
    # currently need the workaround below
    clock_setup_ds.attrs = {}

    clock_ds = clock_setup_ds.xsimlab.run(model=model,
                                          check_dims='transpose')
    last_step = clock_ds.p__var.isel(clock=-1).values
    np.testing.assert_array_equal(last_step, values)

def test_run_validate(self, model, in_dataset):
in_dataset['roll__shift'] = 2.5

Expand Down
16 changes: 15 additions & 1 deletion xsimlab/xr_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,13 +492,26 @@ def filter_vars(self, model=None):

return ds

def run(self, model=None, validate='inputs', safe_mode=True):
def run(self, model=None, check_dims='strict', validate='inputs',
safe_mode=True):
"""Run the model.
Parameters
----------
model : :class:`xsimlab.Model` object, optional
Reference model. If None, tries to get model from context.
check_dims : str, optional
Check the dimension(s) of each input variable given in Dataset.
It may be one of the following options:
- 'strict': the dimension labels must exactly correspond to
(one of) the label sequences defined by their respective model
variables (default)
- 'transpose': input variables might be transposed in order to
match (one of) the label sequences defined by their respective
model variables
If None is given, no check is performed.
validate : {'nothing', 'inputs', 'all'}, optional
Define what will be validated using the variable's validators
defined in ``model``'s processes (if any). It should be one of the
Expand Down Expand Up @@ -531,6 +544,7 @@ def run(self, model=None, validate='inputs', safe_mode=True):
output_store = InMemoryOutputStore()

driver = XarraySimulationDriver(self._ds, model, store, output_store,
check_dims=check_dims,
validate=validate)

return driver.run_model()
Expand Down

0 comments on commit 260ed14

Please sign in to comment.