Skip to content

Commit

Permalink
Fix writing index variables for parallel batches (#123)
Browse files Browse the repository at this point in the history
* Fix writing index variables for parallel batches

Issue when running batches in parallel using a multi-process scheduler:
the model state is not returned by the execution of dask delayed
functions, so the state is lost.

Fix by calling store.write_index_vars for each model run. It is not a
big deal to write those datasets many times (the values should be the
same across all simulations).

* clean up
  • Loading branch information
benbovy committed Apr 7, 2020
1 parent 9d7cbc0 commit 339e1a5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
7 changes: 3 additions & 4 deletions xsimlab/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,7 @@ def run_model(self):
self.store.write_input_xr_dataset()

if self.batch_dim is None:
model = self.model
self._run_one_model(self.dataset, model, parallel=self.parallel)
self._run_one_model(self.dataset, self.model, parallel=self.parallel)

else:
ds_gby_batch = self.dataset.groupby(self.batch_dim)
Expand All @@ -348,8 +347,6 @@ def run_model(self):
if self.parallel:
dask.compute(futures, scheduler=self.scheduler)

self.store.write_index_vars(model=model)

def _run_one_model(self, dataset, model, batch=-1, parallel=False):
"""Run one simulation.
Expand Down Expand Up @@ -406,3 +403,5 @@ def _run_one_model(self, dataset, model, batch=-1, parallel=False):
self.store.write_output_vars(batch, -1, model=model)

model.execute("finalize", rt_context, **execute_kwargs)

self.store.write_index_vars(model=model)
4 changes: 4 additions & 0 deletions xsimlab/tests/test_xr_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,10 @@ def test_run_batch_dim(self, dims, data, clock, parallel, scheduler):
class P:
in_var = xs.variable(dims=[(), "x"])
out_var = xs.variable(dims=[(), "x"], intent="out")
idx_var = xs.index(dims="x")

def initialize(self):
self.idx_var = [0, 1]

def run_step(self):
self.out_var = self.in_var * 2
Expand Down

0 comments on commit 339e1a5

Please sign in to comment.