Skip to content

Commit

Permalink
Reset output variables: avoid update attrs in original dataset (#101)
Browse files Browse the repository at this point in the history
* avoid update attrs in original dataset

* update release notes
  • Loading branch information
benbovy committed Feb 26, 2020
1 parent dfced96 commit 220ea1b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 10 deletions.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ Bug fixes

- Remove ``attrs`` 19.2.0 depreciation warning (:issue:`68`).
- Fix compatibility with xarray 0.14.1 (:issue:`69`).
- Avoid update in-place attributes in original/input xarray Datasets
(:issue:`101`).

v0.3.0 (30 September 2019)
--------------------------
Expand Down
2 changes: 1 addition & 1 deletion xsimlab/tests/fixture_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def in_dataset():
)

ds["clock"].attrs[svars_key] = "profile__u"
ds["out"].attrs[svars_key] = "roll__u_diff," "add__u_diff"
ds["out"].attrs[svars_key] = "roll__u_diff,add__u_diff"
ds.attrs[svars_key] = "profile__u_opp"

return ds
Expand Down
4 changes: 4 additions & 0 deletions xsimlab/tests/test_xr_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ def test_filter_vars(self, simple_model, in_dataset):
assert sorted(filtered_ds.xsimlab.clock_coords) == ["clock", "out"]
assert filtered_ds.out.attrs[self._output_vars_key] == "roll__u_diff"

# test unchanged attributes in original dataset
assert in_dataset.out.attrs[self._output_vars_key] == "roll__u_diff,add__u_diff"
assert in_dataset.attrs[self._output_vars_key] == "profile__u_opp"

def test_set_output_vars(self, model):
ds = xr.Dataset()
ds["clock"] = (
Expand Down
33 changes: 24 additions & 9 deletions xsimlab/xr_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,26 @@ def _set_input_vars(self, model, input_vars):

self._ds[xr_var_name] = xr_var

def _set_output_vars_attr(self, clock, value):
# avoid update attrs in original dataset

if clock is None:
attrs = self._ds.attrs.copy()
else:
attrs = self._ds[clock].attrs.copy()

if value is None:
attrs.pop(self._output_vars_key, None)
else:
attrs[self._output_vars_key] = value

if clock is None:
self._ds.attrs = attrs
else:
new_coord = self._ds.coords[clock].copy()
new_coord.attrs = attrs
self._ds[clock] = new_coord

def _set_output_vars(self, model, output_vars, clear=False):
# TODO: remove this ugly code (depreciated output_vars format)
o_vars = {}
Expand Down Expand Up @@ -297,18 +317,13 @@ def _set_output_vars(self, model, output_vars, clear=False):

for clock, var_list in clock_vars.items():
var_str = ",".join(var_list)

if clock is None:
self._ds.attrs[self._output_vars_key] = var_str
else:
coord = self.clock_coords[clock]
coord.attrs[self._output_vars_key] = var_str
self._set_output_vars_attr(clock, var_str)

def _reset_output_vars(self, model, output_vars):
self._ds.attrs.pop(self._output_vars_key, None)
self._set_output_vars_attr(None, None)

for coord in self.clock_coords.values():
coord.attrs.pop(self._output_vars_key, None)
for clock in self.clock_coords:
self._set_output_vars_attr(clock, None)

self._set_output_vars(model, output_vars, clear=True)

Expand Down

0 comments on commit 220ea1b

Please sign in to comment.