Skip to content

Commit

Permalink
Refactor xarray simulation driver and zarr store internals (#114)
Browse files Browse the repository at this point in the history
* refactor xarray driver and zarr store internals

* black
  • Loading branch information
benbovy committed Mar 19, 2020
1 parent c59ff44 commit 24ec282
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 88 deletions.
150 changes: 90 additions & 60 deletions xsimlab/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,77 @@ def run_model(self):
raise NotImplementedError()


def _check_missing_master_clock(dataset):
if dataset.xsimlab.master_clock_dim is None:
raise ValueError("Missing master clock dimension / coordinate")


def _check_missing_inputs(dataset, model):
"""Check if all model inputs have their corresponding variables
in the input Dataset.
"""
missing_xr_vars = []

for p_name, var_name in model.input_vars:
xr_var_name = p_name + "__" + var_name

if xr_var_name not in dataset:
missing_xr_vars.append(xr_var_name)

if missing_xr_vars:
raise KeyError(f"Missing variables {missing_xr_vars} in Dataset")


def _get_all_active_hooks(hooks):
"""Get all active runtime hooks (i.e, provided as argument, activated from
context manager or glabally registered) and return them grouped by runtime
event.
"""
active_hooks = set(hooks) | RuntimeHook.active

return group_hooks(flatten_hooks(active_hooks))


def _generate_runtime_datasets(dataset):
"""Create xarray Dataset objects that will be used during runtime of one
simulation.
Return a 2-length tuple where the 1st item is a Dataset used
at the initialize stage and the 2st item is a DatasetGroupBy for
iteration through run steps.
Runtime data is added to those datasets.
"""
mclock_dim = dataset.xsimlab.master_clock_dim
mclock_coord = dataset[mclock_dim]

init_data_vars = {
"_sim_start": mclock_coord[0],
"_nsteps": dataset.xsimlab.nsteps,
"_sim_end": mclock_coord[-1],
}

ds_init = dataset.assign(init_data_vars).drop_dims(mclock_dim)

step_data_vars = {
"_clock_start": mclock_coord,
"_clock_end": mclock_coord.shift({mclock_dim: 1}),
"_clock_diff": mclock_coord.diff(mclock_dim, label="lower"),
}

ds_all_steps = (
dataset.drop(list(ds_init.data_vars.keys()), errors="ignore")
.isel({mclock_dim: slice(0, -1)})
.assign(step_data_vars)
)

ds_gby_steps = ds_all_steps.groupby(mclock_dim)

return ds_init, ds_gby_steps


class XarraySimulationDriver(BaseSimulationDriver):
"""Simulation driver using xarray.Dataset objects as I/O.
Expand All @@ -173,22 +244,23 @@ def __init__(
self,
dataset,
model,
state,
store,
encoding,
state=None,
store=None,
encoding=None,
check_dims=CheckDimsOption.STRICT,
validate=ValidateOption.INPUTS,
hooks=None,
):
_check_missing_master_clock(dataset)
_check_missing_inputs(dataset, model)

self.dataset = dataset
self.model = model

super(XarraySimulationDriver, self).__init__(model, state)

if self.dataset.xsimlab.master_clock_dim is None:
raise ValueError("Missing master clock dimension / coordinate")
if state is None:
state = {}

self._check_missing_model_inputs()
super(XarraySimulationDriver, self).__init__(model, state)

if check_dims is not None:
check_dims = CheckDimsOption(check_dims)
Expand All @@ -201,26 +273,12 @@ def __init__(
self._validate_option = validate

if hooks is None:
hooks = set()
hooks = set(hooks) | RuntimeHook.active
self._hooks = group_hooks(flatten_hooks(hooks))

self.store = ZarrSimulationStore(dataset, model, store, encoding)

def _check_missing_model_inputs(self):
"""Check if all model inputs have their corresponding variables
in the input Dataset.
"""
missing_xr_vars = []

for p_name, var_name in self.model.input_vars:
xr_var_name = p_name + "__" + var_name
hooks = []
self.hooks = _get_all_active_hooks(hooks)

if xr_var_name not in self.dataset:
missing_xr_vars.append(xr_var_name)

if missing_xr_vars:
raise KeyError(f"Missing variables {missing_xr_vars} in Dataset")
self.store = ZarrSimulationStore(
dataset, model, zobject=store, encoding=encoding
)

def _maybe_transpose(self, xr_var, p_name, var_name):
var = variables_dict(self.model[p_name].__class__)[var_name]
Expand Down Expand Up @@ -267,34 +325,6 @@ def _get_input_vars(self, dataset):

return input_vars

def _get_runtime_datasets(self):
mclock_dim = self.dataset.xsimlab.master_clock_dim
mclock_coord = self.dataset[mclock_dim]

init_data_vars = {
"_sim_start": mclock_coord[0],
"_nsteps": self.dataset.xsimlab.nsteps,
"_sim_end": mclock_coord[-1],
}

ds_init = self.dataset.assign(init_data_vars).drop_dims(mclock_dim)

step_data_vars = {
"_clock_start": mclock_coord,
"_clock_end": mclock_coord.shift({mclock_dim: 1}),
"_clock_diff": mclock_coord.diff(mclock_dim, label="lower"),
}

ds_all_steps = (
self.dataset.drop(list(ds_init.data_vars.keys()), errors="ignore")
.isel({mclock_dim: slice(0, -1)})
.assign(step_data_vars)
)

ds_gby_steps = ds_all_steps.groupby(mclock_dim)

return ds_init, ds_gby_steps

def _maybe_validate_inputs(self, input_vars):
p_names = set([v[0] for v in input_vars])

Expand Down Expand Up @@ -338,7 +368,7 @@ def run_model(self):
"""
self.store.write_input_xr_dataset()
ds_init, ds_gby_steps = self._get_runtime_datasets()
ds_init, ds_gby_steps = _generate_runtime_datasets(self.dataset)

validate_all = self._validate_option is ValidateOption.ALL

Expand All @@ -353,7 +383,7 @@ def run_model(self):
self._maybe_validate_inputs(in_vars)

self.model.execute(
"initialize", runtime_context, hooks=self._hooks, validate=validate_all,
"initialize", runtime_context, hooks=self.hooks, validate=validate_all,
)

for step, (_, ds_step) in enumerate(ds_gby_steps):
Expand All @@ -370,21 +400,21 @@ def run_model(self):
self._maybe_validate_inputs(in_vars)

self.model.execute(
"run_step", runtime_context, hooks=self._hooks, validate=validate_all,
"run_step", runtime_context, hooks=self.hooks, validate=validate_all,
)

self.store.write_output_vars(step)

self.model.execute(
"finalize_step",
runtime_context,
hooks=self._hooks,
hooks=self.hooks,
validate=validate_all,
)

self.store.write_output_vars(-1)
self.store.write_index_vars()

self.model.execute("finalize", runtime_context, hooks=self._hooks)
self.model.execute("finalize", runtime_context, hooks=self.hooks)

return self._get_output_dataset()
4 changes: 2 additions & 2 deletions xsimlab/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def __init__(
self,
dataset: xr.Dataset,
model: Model,
zobject: Union[zarr.Group, MutableMapping, str, None],
encoding: Union[Dict[str, Dict[str, Any]], None],
zobject: Union[zarr.Group, MutableMapping, str, None] = None,
encoding: Union[Dict[str, Dict[str, Any]], None] = None,
):
self.dataset = dataset
self.model = model
Expand Down
17 changes: 6 additions & 11 deletions xsimlab/tests/test_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ def base_driver(model):

@pytest.fixture
def xarray_driver(in_dataset, model):
state = {}
return XarraySimulationDriver(in_dataset, model, state, None, None)
return XarraySimulationDriver(in_dataset, model)


def test_runtime_context():
Expand Down Expand Up @@ -88,17 +87,13 @@ def test_run_model(self, base_driver):

class TestXarraySimulationDriver:
def test_constructor(self, in_dataset, model):
state = {}

invalid_ds = in_dataset.drop("clock")
with pytest.raises(ValueError) as excinfo:
XarraySimulationDriver(invalid_ds, model, state, None, None)
assert "Missing master clock" in str(excinfo.value)
with pytest.raises(ValueError, match=r"Missing master clock.*"):
XarraySimulationDriver(invalid_ds, model)

invalid_ds = in_dataset.drop("init_profile__n_points")
with pytest.raises(KeyError) as excinfo:
XarraySimulationDriver(invalid_ds, model, state, None, None)
assert "Missing variables" in str(excinfo.value)
with pytest.raises(KeyError, match=r"Missing variables.*"):
XarraySimulationDriver(invalid_ds, model)

@pytest.mark.parametrize(
"value,is_scalar", [(1, True), (("x", [1, 1, 1, 1, 1]), False)]
Expand Down Expand Up @@ -142,7 +137,7 @@ def run_step(self, arg):

m = model.update_processes({"p": P})

driver = XarraySimulationDriver(in_dataset, m, {}, None, None)
driver = XarraySimulationDriver(in_dataset, m)

with pytest.raises(KeyError, match="'not_a_runtime_arg'"):
driver.run_model()
19 changes: 9 additions & 10 deletions xsimlab/tests/test_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ class TestZarrSimulationStore:
"zobject", [None, mkdtemp(), zarr.MemoryStore(), zarr.group()]
)
def test_constructor(self, in_dataset, model, zobject):
out_store = ZarrSimulationStore(in_dataset, model, zobject, None)
out_store = ZarrSimulationStore(in_dataset, model)

assert out_store.zgroup.store is not None

def test_write_input_xr_dataset(self, in_dataset, model):
out_store = ZarrSimulationStore(in_dataset, model, None, None)
out_store = ZarrSimulationStore(in_dataset, model)

out_store.write_input_xr_dataset()
ds = xr.open_zarr(out_store.zgroup.store, chunks=None)
Expand All @@ -40,7 +40,7 @@ def test_write_input_xr_dataset(self, in_dataset, model):

def test_write_output_vars(self, in_dataset, model):
_bind_state(model)
out_store = ZarrSimulationStore(in_dataset, model, None, None)
out_store = ZarrSimulationStore(in_dataset, model)

model.state[("profile", "u")] = np.array([1.0, 2.0, 3.0])
model.state[("roll", "u_diff")] = np.array([-1.0, 1.0, 0.0])
Expand Down Expand Up @@ -70,7 +70,7 @@ def test_write_output_vars(self, in_dataset, model):

def test_write_output_vars_error(self, in_dataset, model):
_bind_state(model)
out_store = ZarrSimulationStore(in_dataset, model, None, None)
out_store = ZarrSimulationStore(in_dataset, model)

model.state[("profile", "u")] = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
model.state[("roll", "u_diff")] = np.array([-1.0, 1.0, 0.0])
Expand All @@ -81,7 +81,7 @@ def test_write_output_vars_error(self, in_dataset, model):

def test_write_index_vars(self, in_dataset, model):
_bind_state(model)
out_store = ZarrSimulationStore(in_dataset, model, None, None)
out_store = ZarrSimulationStore(in_dataset, model)

model.state[("init_profile", "x")] = np.array([1.0, 2.0, 3.0])

Expand All @@ -102,7 +102,7 @@ class P:
)

_bind_state(model)
out_store = ZarrSimulationStore(in_ds, model, None, None)
out_store = ZarrSimulationStore(in_ds, model)

for step, size in zip([0, 1, 2], [1, 3, 2]):
model.state[("p", "arr")] = np.ones(size)
Expand Down Expand Up @@ -138,8 +138,7 @@ def _get_v2(self):
out_store = ZarrSimulationStore(
in_ds,
model,
None,
{"p__v2": {"fill_value": -1}, "p__v3": {"compressor": None}},
encoding={"p__v2": {"fill_value": -1}, "p__v3": {"compressor": None}},
)

model.state[("p", "v1")] = [0]
Expand All @@ -155,7 +154,7 @@ def _get_v2(self):

def test_open_as_xr_dataset(self, in_dataset, model):
_bind_state(model)
out_store = ZarrSimulationStore(in_dataset, model, None, None)
out_store = ZarrSimulationStore(in_dataset, model)

model.state[("profile", "u")] = np.array([1.0, 2.0, 3.0])
model.state[("roll", "u_diff")] = np.array([-1.0, 1.0, 0.0])
Expand All @@ -168,7 +167,7 @@ def test_open_as_xr_dataset(self, in_dataset, model):

def test_open_as_xr_dataset_chunks(self, in_dataset, model):
_bind_state(model)
out_store = ZarrSimulationStore(in_dataset, model, mkdtemp(), None)
out_store = ZarrSimulationStore(in_dataset, model, zobject=mkdtemp())

model.state[("profile", "u")] = np.array([1.0, 2.0, 3.0])
model.state[("roll", "u_diff")] = np.array([-1.0, 1.0, 0.0])
Expand Down
7 changes: 2 additions & 5 deletions xsimlab/xr_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,14 +725,11 @@ def run(
if safe_mode:
model = model.clone()

state = {}

driver = XarraySimulationDriver(
self._ds,
model,
state,
store,
encoding,
store=store,
encoding=encoding,
check_dims=check_dims,
validate=validate,
hooks=hooks,
Expand Down

0 comments on commit 24ec282

Please sign in to comment.