Skip to content

Commit

Permalink
Add Model.cache property (#125)
Browse files Browse the repository at this point in the history
* gitignore unreleated

* add Model.cache property

* use public API outside of Model

* update doc
  • Loading branch information
benbovy committed Apr 9, 2020
1 parent 42e603d commit b603d58
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ wheels/
.installed.cfg
*.egg

# dask
dask-worker-space

# Installer logs
pip-log.txt

Expand Down
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ e.g., for using ``Model`` objects with other interfaces.

Model.state
Model.update_state
Model.cache
Model.update_cache
Model.execute
Model.validate
Expand Down
5 changes: 5 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ Release Notes
v0.5.0 (Unreleased)
-------------------

Enhancements
~~~~~~~~~~~~

- Added :attr:`xsimlab.Model.cache` public property (:issue:`125`).

Bug fixes
~~~~~~~~~

Expand Down
4 changes: 2 additions & 2 deletions xsimlab/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,14 @@ def _maybe_transpose(dataset, model, check_dims, batch_dim):
ds_transposed = dataset.copy()

for var_key in model.input_vars:
xr_var_name = model._var_cache[var_key]["name"]
xr_var_name = model.cache[var_key]["name"]
xr_var = dataset.get(xr_var_name)

if xr_var is None:
continue

# all valid dimensions in the right order
dims = [list(d) for d in model._var_cache[var_key]["metadata"]["dims"]]
dims = [list(d) for d in model.cache[var_key]["metadata"]["dims"]]
dims += [[dataset.xsimlab.master_clock_dim] + d for d in dims]
if batch_dim is not None:
dims += [[batch_dim] + d for d in dims]
Expand Down
10 changes: 10 additions & 0 deletions xsimlab/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,16 @@ def update_state(
p_names = set([pn for pn, _ in input_vars if pn in self._processes])
self.validate(p_names)

@property
def cache(self):
"""Returns a mapping of model variables and some of their (meta)data cached for
fastpath access.
Mapping keys are in the form of ``('process_name', 'var_name')`` tuples.
"""
return self._var_cache

def update_cache(self, var_key):
"""Update the model's cache for a given model variable.
Expand Down
10 changes: 5 additions & 5 deletions xsimlab/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _get_var_info(
var_clocks.update({vk: None for vk in model.index_vars})

for var_key, clock in var_clocks.items():
var_cache = model._var_cache[var_key]
var_cache = model.cache[var_key]

# encoding defined at model run
run_encoding = normalize_encoding(
Expand Down Expand Up @@ -179,7 +179,7 @@ def _create_zarr_dataset(
if name is None:
name = var_info["name"]

value = model._var_cache[var_key]["value"]
value = model.cache[var_key]["value"]
clock = var_info["clock"]

dtype = getattr(value, "dtype", np.asarray(value).dtype)
Expand Down Expand Up @@ -251,7 +251,7 @@ def _maybe_resize_zarr_dataset(

zkey = var_info["name"]
zshape = self.zgroup[zkey].shape
value = model._var_cache[var_key]["value"]
value = model.cache[var_key]["value"]
value_shape = list(np.shape(value))

# maybe prepend clock dim (do not resize this dim)
Expand Down Expand Up @@ -292,7 +292,7 @@ def write_output_vars(self, batch: int, step: int, model: Optional[Model] = None

for vk in var_keys:
zkey = self.var_info[vk]["name"]
value = model._var_cache[vk]["value"]
value = model.cache[vk]["value"]

self._maybe_resize_zarr_dataset(model, vk)

Expand Down Expand Up @@ -325,7 +325,7 @@ def write_index_vars(self, model: Optional[Model] = None):
model.update_cache(var_key)

self._create_zarr_dataset(model, var_key, name=vname)
self.zgroup[vname][:] = model._var_cache[var_key]["value"]
self.zgroup[vname][:] = model.cache[var_key]["value"]

def consolidate(self):
zarr.consolidate_metadata(self.zgroup.store)
Expand Down
4 changes: 2 additions & 2 deletions xsimlab/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,13 @@ def test_update_cache(self, model):
model.state[("init_profile", "n_points")] = 10
model.update_cache(("init_profile", "n_points"))

assert model._var_cache[("init_profile", "n_points")]["value"] == 10
assert model.cache[("init_profile", "n_points")]["value"] == 10

# test on demand variables
model.state[("add", "offset")] = 1
model.update_cache(("add", "u_diff"))

assert model._var_cache[("add", "u_diff")]["value"] == 1
assert model.cache[("add", "u_diff")]["value"] == 1

def test_validate(self, model):
model.state[("roll", "shift")] = 2.5
Expand Down

0 comments on commit b603d58

Please sign in to comment.