Skip to content

Commit

Permalink
Refactor output_vars in the xarray extension (#85)
Browse files Browse the repository at this point in the history
* update (switch) `output_vars` dict key/values

Keys become the variable names and values become the clock dimension or
None.

For backward compatibility, input dicts with clock dimension or None as
keys are still supported.

* update / fix existing tests

* add deprecation warning + update tests

* black + tweaks

* update docstrings and doc

* update release notes

* doc tweaks
  • Loading branch information
benbovy committed Dec 24, 2019
1 parent a8c2b59 commit 1d995b4
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 131 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ by Jake VanderPlas.
model=model,
clocks={'step': np.arange(9)},
input_vars={'init__pos': ('point_xy', [4, 5])},
output_vars={'step': 'gol__world'}
output_vars={'gol__world': 'step'}
)
output_dataset = input_dataset.xsimlab.run(model=model)
Expand Down
2 changes: 1 addition & 1 deletion doc/about.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ by Jake VanderPlas.
...: model=model,
...: clocks={'step': np.arange(9)},
...: input_vars={'init__pos': ('point_xy', [4, 5])},
...: output_vars={'step': 'gol__world'}
...: output_vars={'gol__world': 'step'}
...: )
...:
...: output_dataset = input_dataset.xsimlab.run(model=model)
Expand Down
20 changes: 13 additions & 7 deletions doc/run_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,20 @@ create a new setup in a very declarative way:
in_ds = xs.create_setup(
model=model2,
clocks={'time': np.linspace(0., 1., 101),
'otime': [0, 0.5, 1]},
clocks={
'time': np.linspace(0., 1., 101),
'otime': [0, 0.5, 1]
},
master_clock='time',
input_vars={'grid': {'length': 1.5, 'spacing': 0.01},
'init': {'loc': 0.3, 'scale': 0.1},
'advect': {'v': 1.}},
output_vars={None: {'grid': 'x'},
'otime': {'profile': 'u'}}
input_vars={
'grid': {'length': 1.5, 'spacing': 0.01},
'init': {'loc': 0.3, 'scale': 0.1},
'advect__v': 1.
},
output_vars={
'grid__x': None,
'profile__u': 'otime'
}
)
A setup consists of:
Expand Down
14 changes: 12 additions & 2 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,20 @@ Breaking changes
~~~~~~~~~~~~~~~~

- Python 3.6 is now the oldest supported version (:issue:`70`).
- The keys of the dictionary returned by
:attr:`xarray.Dataset.xsimlab.output_vars` now correspond to variable names,
and the values are clock dimension labels or ``None`` (previously the
dictionary was formatted the other way around) (:issue:`85`).

Deprecations
~~~~~~~~~~~~~

- Using the ``group`` parameter in ``xsimlab.variable`` and
``xsimlab.on_demand`` is deprecated; use ``groups`` instead.
- Using the ``group`` parameter in :func:`xsimlab.variable` and
:func:`xsimlab.on_demand` is deprecated; use ``groups`` instead.
- Providing a dictionary with clock dimensions or ``None`` as keys to
``output_vars`` in :func:`xarray.Dataset.xsimlab.update_vars()` and
:func:`xsimlab.create_setup()` is deprecated. Variable names should be used
instead (:issue:`85`).

Enhancements
~~~~~~~~~~~~
Expand All @@ -37,6 +45,8 @@ Enhancements
input xarray Datasets to match those defined in model variables (:issue:`76`).
This is optional and controlled by the parameter ``check_dims`` added
to :func:`xarray.Dataset.xsimlab.run`.
- More consistent dictionary format for output variables in the xarray
extension (:issue:`85`).

Bug fixes
~~~~~~~~~
Expand Down
52 changes: 18 additions & 34 deletions xsimlab/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,20 +242,16 @@ def _get_output_save_steps(self):
"""
save_steps = {}

for clock in self.output_vars:
if clock is None:
continue

elif clock == self.master_clock_dim:
save_steps[clock] = np.ones_like(
self.dataset[self.master_clock_dim].values, dtype=bool
)
master_coord = self.dataset[self.master_clock_dim]

for clock, coord in self.dataset.xsimlab.clock_coords.items():
if clock == self.master_clock_dim:
save_steps[clock] = np.ones_like(coord.values, dtype=bool)
else:
save_steps[clock] = np.in1d(
self.dataset[self.master_clock_dim].values,
self.dataset[clock].values,
)
save_steps[clock] = np.in1d(master_coord.values, coord.values)

save_steps[None] = np.zeros_like(master_coord.values, dtype=bool)
save_steps[None][-1] = True

return save_steps

Expand Down Expand Up @@ -306,17 +302,13 @@ def _get_input_vars(self, dataset):
return input_vars

def _maybe_save_output_vars(self, istep):
# TODO: optimize this for performance
for clock, var_keys in self.output_vars.items():
save_output = (
clock is None
and istep == -1
or clock is not None
and self.output_save_steps[clock][istep]
)
var_list = []

for key, clock in self.output_vars.items():
if self.output_save_steps[clock][istep]:
var_list.append(key)

if save_output:
self.update_output_store(var_keys)
self.update_output_store(var_list)

def _to_xr_variable(self, key, clock):
"""Convert an output variable to a xarray.Variable object.
Expand Down Expand Up @@ -359,25 +351,17 @@ def _get_output_dataset(self):
"""Return a new dataset as a copy of the input dataset updated with
output variables.
"""
from .xr_accessor import SimlabAccessor

xr_vars = {}

for clock, vars in self.output_vars.items():
for key in vars:
var_name = "__".join(key)
xr_vars[var_name] = self._to_xr_variable(key, clock)
for key, clock in self.output_vars.items():
var_name = "__".join(key)
xr_vars[var_name] = self._to_xr_variable(key, clock)

out_ds = self.dataset.copy()
out_ds.update(xr_vars)

# remove output_vars attributes in output dataset
for clock in self.output_vars:
if clock is None:
attrs = out_ds.attrs
else:
attrs = out_ds[clock].attrs
attrs.pop(SimlabAccessor._output_vars_key)
out_ds.xsimlab._reset_output_vars(self.model, {})

return out_ds

Expand Down
1 change: 1 addition & 0 deletions xsimlab/tests/test_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def test_output_save_steps(self, xarray_driver):
expected = {
"clock": np.array([True, True, True, True, True]),
"out": np.array([True, False, True, False, True]),
None: np.array([False, False, False, False, True]),
}

assert xarray_driver.output_save_steps.keys() == expected.keys()
Expand Down
69 changes: 36 additions & 33 deletions xsimlab/tests/test_xr_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_update_vars(self, model, in_dataset):
ds = in_dataset.xsimlab.update_vars(
model=model,
input_vars={("roll", "shift"): 2},
output_vars={"out": ("profile", "u")},
output_vars={("profile", "u"): "out"},
)

assert not ds["roll__shift"].equals(in_dataset["roll__shift"])
Expand Down Expand Up @@ -258,46 +258,48 @@ def test_set_output_vars(self, model):
ds["out"] = ("out", [0, 4, 8], {self._clock_key: 1})
ds["not_a_clock"] = ("not_a_clock", [0, 1])

with pytest.raises(KeyError) as excinfo:
ds.xsimlab._set_output_vars(model, None, [("invalid", "var")])
assert "not valid key(s)" in str(excinfo.value)
with pytest.raises(KeyError, match=r".*not valid key.*"):
ds.xsimlab._set_output_vars(model, {("invalid", "var"): None})

ds.xsimlab._set_output_vars(model, None, [("profile", "u_opp")])
ds.xsimlab._set_output_vars(model, {("profile", "u_opp"): None})
assert ds.attrs[self._output_vars_key] == "profile__u_opp"

ds.xsimlab._set_output_vars(
model, "out", [("roll", "u_diff"), ("add", "u_diff")]
model, {("roll", "u_diff"): "out", ("add", "u_diff"): "out"}
)
expected = "roll__u_diff,add__u_diff"
assert ds["out"].attrs[self._output_vars_key] == expected

with pytest.raises(ValueError) as excinfo:
ds.xsimlab._set_output_vars(model, "not_a_clock", [("profile", "u")])
assert "not a valid clock" in str(excinfo.value)
with pytest.raises(ValueError, match=r".not a valid clock.*"):
ds.xsimlab._set_output_vars(model, {("profile", "u"): "not_a_clock"})

with pytest.warns(FutureWarning):
ds.xsimlab._set_output_vars(model, {None: ("profile", "u_opp")})

with pytest.warns(FutureWarning):
ds.xsimlab._set_output_vars(model, {"out": ("profile", "u_opp")})

def test_output_vars(self, model):
ds = xr.Dataset()
ds["clock"] = (
"clock",
[0, 2, 4, 6, 8],
{self._clock_key: 1, self._master_clock_key: 1},
)
ds["out"] = ("out", [0, 4, 8], {self._clock_key: 1})
# snapshot clock with no output variable (attribute) set
ds["out2"] = ("out2", [0, 8], {self._clock_key: 1})
o_vars = {
("profile", "u_opp"): None,
("profile", "u"): "clock",
("roll", "u_diff"): "out",
("add", "u_diff"): "out",
}

ds.xsimlab._set_output_vars(model, None, [("profile", "u_opp")])
ds.xsimlab._set_output_vars(model, "clock", [("profile", "u")])
ds.xsimlab._set_output_vars(
model, "out", [("roll", "u_diff"), ("add", "u_diff")]
ds = xs.create_setup(
model=model,
clocks={
"clock": [0, 2, 4, 6, 8],
"out": [0, 4, 8],
# snapshot clock with no output variable
"out2": [0, 8],
},
master_clock="clock",
output_vars=o_vars,
)

expected = {
None: [("profile", "u_opp")],
"clock": [("profile", "u")],
"out": [("roll", "u_diff"), ("add", "u_diff")],
}
assert ds.xsimlab.output_vars == expected
assert ds.xsimlab.output_vars == o_vars

def test_run_safe_mode(self, model, in_dataset):
# safe mode True: ensure model is cloned
Expand All @@ -321,7 +323,7 @@ class P:
model=m,
clocks={"clock": [1, 2]},
input_vars={"p__var": (("y", "x"), arr)},
output_vars={None: ["p__var"]},
output_vars={"p__var": None},
)

out_ds = in_ds.xsimlab.run(model=m, check_dims=None)
Expand All @@ -336,7 +338,7 @@ class P:
np.testing.assert_array_equal(actual, arr)
np.testing.assert_array_equal(m.p.var, arr.transpose())

in_ds2 = in_ds.xsimlab.update_vars(model=m, output_vars={"clock": ["p__var"]})
in_ds2 = in_ds.xsimlab.update_vars(model=m, output_vars={"p__var": "clock"})
# TODO: fix update output vars time-independent -> dependent
# currently need the workaround below
in_ds2.attrs = {}
Expand Down Expand Up @@ -405,9 +407,10 @@ def test_create_setup(model, in_dataset):
clocks={"clock": [0, 2, 4, 6, 8], "out": [0, 4, 8]},
master_clock="clock",
output_vars={
"clock": "profile__u",
"out": [("roll", "u_diff"), ("add", "u_diff")],
None: {"profile": "u_opp"},
"profile__u": "clock",
("roll", "u_diff"): "out",
("add", "u_diff"): "out",
"profile": {"u_opp": None},
},
)
xr.testing.assert_identical(ds, in_dataset)
2 changes: 1 addition & 1 deletion xsimlab/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _as_group_tuple(groups, group):

if group is not None:
warnings.warn(
"Setting variable group using `group` is depreciated; " "use `groups`.",
"Setting variable group using `group` is depreciated; use `groups`.",
FutureWarning,
stacklevel=2,
)
Expand Down

0 comments on commit 1d995b4

Please sign in to comment.