Skip to content

Commit

Permalink
maybe raise when updating value of static variable
Browse files Browse the repository at this point in the history
Raise when attempt is made during simulation runtime.
Also update existing tests.
  • Loading branch information
benbovy committed Dec 11, 2019
1 parent 1caa149 commit cd413f1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 19 deletions.
54 changes: 42 additions & 12 deletions xsimlab/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,29 @@ def _bind_store_to_model(self):
for p_obj in self.model.values():
p_obj.__xsimlab_store__ = self.store

def update_store(self, input_vars):
"""Update the simulation active data store with input variable
values.
def _set_in_store(self, input_vars, check_static=True):
for key in self.model.input_vars:
value = input_vars.get(key)

if value is None:
continue

p_name, var_name = key
var = variables_dict(self.model[p_name].__class__)[var_name]

if check_static and var.metadata.get('static', False):
raise RuntimeError("Cannot set value in store for "
"static variable {!r} defined "
"in process {!r}"
.format(p_name, var_name))

self.store[key] = copy.copy(value)

def initialize_store(self, input_vars):
"""Pre-populate the simulation active data store with input
variable values.
This should be called before the simulation starts.
``input_vars`` is a dictionary where keys are store keys, i.e.,
``(process_name, var_name)`` tuples, and values are the input
Expand All @@ -82,11 +102,17 @@ def update_store(self, input_vars):
inputs are silently ignored.
"""
for key in self.model.input_vars:
value = input_vars.get(key)
self._set_in_store(input_vars, check_static=False)

if value is not None:
self.store[key] = copy.copy(value)
def update_store(self, input_vars):
"""Update the simulation active data store with input variable
values.
Like ``initialize_store``, but here meant to be called during
simulation runtime.
"""
self._set_in_store(input_vars, check_static=True)

def update_output_store(self, output_var_keys):
"""Update the simulation output store (i.e., append new values to the
Expand Down Expand Up @@ -188,19 +214,23 @@ def _get_output_save_steps(self):

return save_steps

def _set_input_vars(self, dataset):
def _get_input_vars(self, dataset):
input_vars = {}

for p_name, var_name in self.model.input_vars:
xr_var_name = p_name + '__' + var_name
xr_var = dataset.get(xr_var_name)

if xr_var is not None:
data = xr_var.data.copy()
data = xr_var.data

if data.ndim == 0:
# convert array to scalar
data = data.item()

self.store[(p_name, var_name)] = data
input_vars[(p_name, var_name)] = data

return input_vars

def _maybe_save_output_vars(self, istep):
# TODO: optimize this for performance
Expand Down Expand Up @@ -303,7 +333,7 @@ def run_model(self):
sim_end=ds_init['_sim_end'].values
)

self._set_input_vars(ds_init)
self.initialize_store(self._get_input_vars(ds_init))

self.model.execute('initialize', runtime_context)

Expand All @@ -314,7 +344,7 @@ def run_model(self):
step_end=ds_step['_clock_end'].values,
step_delta=ds_step['_clock_diff'].values)

self._set_input_vars(ds_step)
self.update_store(self._get_input_vars(ds_step))

self.model.execute('run_step', runtime_context)
self._maybe_save_output_vars(step)
Expand Down
10 changes: 3 additions & 7 deletions xsimlab/tests/test_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ def test_output_save_steps(self, xarray_driver):
(('init_profile', 'n_points'), True),
(('add', 'offset'), False)
])
def test_set_input_vars(self, in_dataset, xarray_driver,
def test_get_input_vars(self, in_dataset, xarray_driver,
var_key, is_scalar):
xarray_driver._set_input_vars(in_dataset)
in_vars = xarray_driver._get_input_vars(in_dataset)

actual = xarray_driver.store[var_key]
actual = in_vars[var_key]
expected = in_dataset['__'.join(var_key)].data

if is_scalar:
Expand All @@ -106,10 +106,6 @@ def test_set_input_vars(self, in_dataset, xarray_driver,
assert_array_equal(actual, expected)
assert not np.isscalar(actual)

# test copy
actual[0] = -9999
assert not np.array_equal(actual, expected)

def test_get_output_dataset(self, in_dataset, xarray_driver):
# regression test: make sure a copy of input dataset is used
out_ds = xarray_driver.run_model()
Expand Down

0 comments on commit cd413f1

Please sign in to comment.