Skip to content

Commit

Permalink
Clean-up and add some properties/methods to Dataset.xsimlab (#103)
Browse files Browse the repository at this point in the history
* clean-up and add useful properties/methods

* add tests

* doc: add API entries

* black

* update release notes

* add missing API entry in docs
  • Loading branch information
benbovy committed Feb 28, 2020
1 parent 3b289a8 commit d5db973
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 38 deletions.
5 changes: 5 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@ properties listed below. Proper use of this accessor should be like:
:template: autosummary/accessor_attribute.rst

Dataset.xsimlab.clock_coords
Dataset.xsimlab.clock_sizes
Dataset.xsimlab.master_clock_dim
Dataset.xsimlab.master_clock_coord
Dataset.xsimlab.nsteps
Dataset.xsimlab.output_vars
Dataset.xsimlab.output_vars_by_clock

**Methods**

Expand All @@ -53,6 +57,7 @@ properties listed below. Proper use of this accessor should be like:
Dataset.xsimlab.reset_vars
Dataset.xsimlab.filter_vars
Dataset.xsimlab.run
Dataset.xsimlab.get_output_save_steps

Model
=====
Expand Down
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ Enhancements
- Added simulation runtime hooks (:issue:`95`). Hooks can be created by using
either the :func:`~xsimlab.runtime_hook` decorator or the
:class:`~xsimlab.RuntimeHook` class.
- Added some useful properties and methods to the ``xarray.Dataset.xsimlab``
extension (:issue:`103`).

Bug fixes
~~~~~~~~~
Expand Down
62 changes: 47 additions & 15 deletions xsimlab/tests/test_xr_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,17 @@ def test_clock_coords(self):
)
assert set(ds.xsimlab.clock_coords) == {"mclock", "sclock"}

def test_clock_sizes(self):
    """Check that ``clock_sizes`` maps each clock coordinate to its length.

    The coordinate without the clock attribute must not be reported.
    """
    coords = {
        "clock1": ("clock1", [0, 1, 2], {self._clock_key: 1}),
        "clock2": ("clock2", [0, 2], {self._clock_key: 1}),
        # no clock attribute on this one
        "no_clock": ("no_clock", [3, 4]),
    }
    dataset = xr.Dataset(coords=coords)

    expected_sizes = {"clock1": 3, "clock2": 2}
    assert dataset.xsimlab.clock_sizes == expected_sizes

def test_master_clock_dim(self):
attrs = {self._clock_key: 1, self._master_clock_key: 1}
ds = xr.Dataset(coords={"clock": ("clock", [1, 2], attrs)})
Expand All @@ -118,18 +129,34 @@ def test_master_clock_dim(self):
ds = xr.Dataset()
assert ds.xsimlab.master_clock_dim is None

def test_nsteps(self):
    """Check ``nsteps`` with and without a master clock.

    A 3-element master clock coordinate is expected to give 2 steps;
    an empty dataset (no master clock) is expected to give 0.
    """
    attrs = {self._clock_key: 1, self._master_clock_key: 1}
    ds = xr.Dataset(coords={"clock": ("clock", [1, 2, 3], attrs)})

    assert ds.xsimlab.nsteps == 2

    ds = xr.Dataset()
    assert ds.xsimlab.nsteps == 0

def test_get_output_save_steps(self):
    """Check the boolean save-step masks computed for non-master clocks.

    Each non-master clock is expected to become a boolean data variable
    along the master clock dimension, True where the master clock value
    is also a value of that clock.
    """
    attrs = {self._clock_key: 1, self._master_clock_key: 1}
    master_clock = ("clock", [0, 1, 2, 3, 4], attrs)

    ds = xr.Dataset(
        coords={
            "clock": master_clock,
            "clock1": ("clock1", [0, 2, 4], {self._clock_key: 1}),
            "clock2": ("clock2", [0, 4], {self._clock_key: 1}),
        }
    )

    masks = {
        "clock1": ("clock", [True, False, True, False, True]),
        "clock2": ("clock", [True, False, False, False, True]),
    }
    expected = xr.Dataset(coords={"clock": master_clock}, data_vars=masks)

    actual = ds.xsimlab.get_output_save_steps()
    xr.testing.assert_identical(actual, expected)


def test_set_input_vars(self, model, in_dataset):
in_vars = {
Expand Down Expand Up @@ -305,6 +332,17 @@ def test_output_vars(self, model):

assert ds.xsimlab.output_vars == o_vars

def test_output_vars_by_clock(self, model):
    """Check that output variables are grouped by the clock they use.

    Variables with no clock are expected under the ``None`` key.
    """
    roll_var = ("roll", "u_diff")
    add_var = ("add", "u_diff")

    ds = xs.create_setup(
        model=model,
        clocks={"clock": [0, 2, 4, 6, 8]},
        output_vars={roll_var: "clock", add_var: None},
    )

    grouped = ds.xsimlab.output_vars_by_clock
    assert grouped == {"clock": [roll_var], None: [add_var]}


def test_run_safe_mode(self, model, in_dataset):
# safe mode True: ensure model is cloned
_ = in_dataset.xsimlab.run(model=model, safe_mode=True)
Expand Down Expand Up @@ -385,12 +423,6 @@ def initialize(self):
with pytest.raises(TypeError, match=r".*'int'.*"):
in_dataset.xsimlab.run(model=m, validate="all")

def test_run_multi(self):
    """Check that ``run_multi`` raises ``NotImplementedError``."""
    empty_dataset = xr.Dataset()

    with pytest.raises(NotImplementedError):
        empty_dataset.xsimlab.run_multi()


def test_create_setup(model, in_dataset):
expected = xr.Dataset()
Expand Down
110 changes: 87 additions & 23 deletions xsimlab/xr_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .drivers import XarraySimulationDriver
from .model import Model
from .stores import InMemoryOutputStore
from .utils import variables_dict
from .utils import Frozen, variables_dict


@register_dataset_accessor("filter")
Expand Down Expand Up @@ -134,17 +134,31 @@ class SimlabAccessor:
def __init__(self, ds):
    """Attach the accessor to the dataset ``ds`` with empty caches."""
    # caches start unset; they are filled on first access by the
    # ``clock_coords`` and ``master_clock_dim`` properties
    self._clock_coords = None
    self._master_clock_dim = None
    self._ds = ds

@property
def clock_coords(self):
    """Mapping from clock dimensions to :class:`xarray.DataArray` objects
    corresponding to their coordinates.

    Cannot be modified directly.

    """
    # Build the mapping once and cache it; a coordinate counts as a clock
    # when the clock key is present in its attributes.
    if self._clock_coords is None:
        self._clock_coords = {
            k: coord
            for k, coord in self._ds.coords.items()
            if self._clock_key in coord.attrs
        }

    # the cached dict is returned wrapped in ``Frozen``
    return Frozen(self._clock_coords)


@property
def clock_sizes(self):
    """Mapping from clock dimensions to lengths.

    Cannot be modified directly.

    """
    # sizes are derived from the clock coordinates on every access
    return Frozen({k: coord.size for k, coord in self.clock_coords.items()})

@property
def master_clock_dim(self):
Expand All @@ -156,6 +170,10 @@ def master_clock_dim(self):
:meth:`Dataset.xsimlab.update_clocks`
"""
# it is fine to cache the value here as inconsistency may appear
# only when deleting the master clock coordinate from the dataset,
# which would raise early anyway

if self._master_clock_dim is not None:
return self._master_clock_dim
else:
Expand All @@ -166,6 +184,47 @@ def master_clock_dim(self):
return dim
return None

@property
def master_clock_coord(self):
    """Master clock coordinate (as a :class:`xarray.DataArray` object).

    Returns None if no master clock is defined in the dataset.

    """
    dim = self.master_clock_dim
    # ``get`` gives None when ``dim`` is not found in the dataset
    coord = self._ds.get(dim)
    return coord


@property
def nsteps(self):
    """Number of simulation steps, computed from the master
    clock coordinate.

    Returns 0 if no master clock is defined in the dataset.

    """
    dim = self.master_clock_dim

    if dim is not None:
        # one step fewer than the number of master clock values
        return self._ds[dim].size - 1

    return 0


def get_output_save_steps(self):
    """Returns save steps for each clock as boolean values.

    Returns
    -------
    save_steps : :class:`xarray.Dataset`
        A new Dataset with boolean data variables for each clock
        dimension other than the master clock, where values specify
        whether or not to save outputs at every step of a simulation.

    """
    # NOTE(review): this assumes a master clock is defined in the dataset;
    # the no-master-clock case is not handled here -- confirm with callers.
    master_dim = self.master_clock_dim
    master_coord = self.master_clock_coord

    ds = Dataset(coords={master_dim: master_coord})

    for clock, coord in self.clock_coords.items():
        if clock != master_dim:
            # True at each master clock value that is also a value of
            # this clock. ``np.isin`` replaces ``np.in1d``, which is
            # deprecated in NumPy 2.0.
            save_steps = np.isin(master_coord.values, coord.values)
            ds[clock] = (master_dim, save_steps)

    return ds


def _set_clock_coord(self, dim, data):
xr_var = as_variable(data, name=dim)

Expand Down Expand Up @@ -293,7 +352,7 @@ def _set_output_vars(self, model, output_vars, clear=False):
# end of deprecated code block

if not clear:
_output_vars = self.output_vars
_output_vars = {k: v for k, v in self.output_vars.items()}
_output_vars.update(output_vars)
output_vars = _output_vars

Expand All @@ -319,6 +378,10 @@ def _set_output_vars(self, model, output_vars, clear=False):
var_str = ",".join(var_list)
self._set_output_vars_attr(clock, var_str)

# reset clock_coords cache as attributes of those coords
# may have been updated
self._clock_coords = None

def _reset_output_vars(self, model, output_vars):
self._set_output_vars_attr(None, None)

Expand All @@ -333,6 +396,7 @@ def output_vars(self):
``('p_name', 'var_name')`` tuples - as keys and the clock dimension
names (or None) on which to save snapshots as values.
Cannot be modified directly.
"""

def xr_attr_to_dict(attrs, clock):
Expand All @@ -350,7 +414,20 @@ def xr_attr_to_dict(attrs, clock):

o_vars.update(xr_attr_to_dict(self._ds.attrs, None))

return o_vars
return Frozen(o_vars)

@property
def output_vars_by_clock(self):
    """Returns a dictionary of output variables grouped by clock (keys).

    Cannot be modified directly.

    """
    grouped = {}

    for var_key, clock_dim in self.output_vars.items():
        # create the list for a clock the first time it is seen
        grouped.setdefault(clock_dim, []).append(var_key)

    return Frozen(grouped)


def update_clocks(self, model=None, clocks=None, master_clock=None):
"""Set or update clock coordinates.
Expand Down Expand Up @@ -637,19 +714,6 @@ def run(

return driver.run_model()

def run_multi(self):
    """Run multiple models.

    Not yet implemented.

    Raises
    ------
    NotImplementedError
        Always raised.

    See Also
    --------
    :meth:`xarray.Dataset.xsimlab.run`

    """
    # TODO:
    raise NotImplementedError



def create_setup(
model=None,
Expand Down

0 comments on commit d5db973

Please sign in to comment.