diff --git a/doc/whats_new.rst b/doc/whats_new.rst index e0781ca..76f7875 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -15,6 +15,8 @@ v0.6.0 (Unreleased) This will be set to the main clock when storing the dataset. - Changed default ``fill_value`` in the zarr stores to maximum dtype value for integer dtypes and ``np.nan`` for floating-point variables. + - Added custom dependencies as option at model creation e.g. + ``xs.Model({"a":A,"b":B},custom_dependencies={"a":"b"})`` v0.5.0 (26 January 2021) ------------------------ diff --git a/xsimlab/model.py b/xsimlab/model.py index 135aaf6..e4a9200 100644 --- a/xsimlab/model.py +++ b/xsimlab/model.py @@ -401,7 +401,7 @@ def get_processes_to_validate(self): return {k: list(v) for k, v in processes_to_validate.items()} - def get_process_dependencies(self): + def get_process_dependencies(self, custom_dependencies={}): """Return a dictionary where keys are each process of the model and values are lists of the names of dependent processes (or empty lists for processes that have no dependencies). @@ -423,6 +423,10 @@ def get_process_dependencies(self): ] ) + # actually add custom dependencies + for p_name, deps in custom_dependencies.items(): + self._dep_processes[p_name].update(deps) + for p_name, p_obj in self._processes_obj.items(): for var in filter_variables(p_obj, intent=VarIntent.OUT).values(): if var.metadata["var_type"] == VarType.ON_DEMAND: @@ -455,6 +459,7 @@ def _sort_processes(self): """ ordered = [] + self._deps_dict = {p: set() for p in self._dep_processes} # Nodes whose descendents have been completely explored. # These nodes are guaranteed to not be part of a cycle. @@ -484,18 +489,19 @@ def _sort_processes(self): # Add direct descendants of cur to nodes stack next_nodes = [] for nxt in self._dep_processes[cur]: - if nxt not in completed: - if nxt in seen: - # Cycle detected! - cycle = [nxt] - while nodes[-1] != nxt: - cycle.append(nodes.pop()) + if nxt in seen: + # Cycle detected!
+ cycle = [nxt] + while nodes[-1] != nxt: cycle.append(nodes.pop()) - cycle.reverse() - cycle = "->".join(cycle) - raise RuntimeError( - f"Cycle detected in process graph: {cycle}" - ) + cycle.append(nodes.pop()) + cycle.reverse() + cycle = "->".join(cycle) + raise RuntimeError(f"Cycle detected in process graph: {cycle}") + if nxt in completed: + self._deps_dict[cur].add(nxt) + self._deps_dict[cur].update(self._deps_dict[nxt]) + else: next_nodes.append(nxt) if next_nodes: @@ -507,8 +513,148 @@ def _sort_processes(self): completed.add(cur) seen.remove(cur) nodes.pop() + return ordered + def _strict_order_check(self): + """ + IMPORTANT: _sort_processes should be run first + checks if all inout variables and corresponding in variables are explicitly set in the dependencies + Out variables always come first, since the get_process_dependencies checks for that. + A well-behaved graph looks like: ``in0->inout1->in1->inout2->in2`` + """ + # create dictionaries with all inout variables and input variables + inout_dict = {} # dict of {key:{p1_name,p2_name}} for inout variables + # TODO: improve this: the aim is to create a {key:{p1,p2,p3}} dict, + # where p1,p2,p3 are process names that have the key var as inout, resp. in vars + # some problems are that we can have on_demand and state variables, + # that key can return a tuple or list, + for p_name, p_obj in self._processes_obj.items(): + # create {key:{p1_name,p2_name}} dicts for in and inout vars.
+ for var in filter_variables(p_obj, intent=VarIntent.INOUT).values(): + state_key, od_key = self._get_var_key(p_name, var) + if state_key is not None: + if not state_key in inout_dict: + inout_dict[state_key] = {p_name} + else: + inout_dict[state_key].add(p_name) + if od_key is not None: + if not od_key in inout_dict: + inout_dict[od_key] = {p_name} + else: + inout_dict[od_key].add(p_name) + + in_dict = {key: set() for key in inout_dict} + for p_name, p_obj in self._processes_obj.items(): + for var in filter_variables(p_obj, intent=VarIntent.IN).values(): + state_key, od_key = self._get_var_key(p_name, var) + if state_key in in_dict: + in_dict[state_key].add(p_name) + if od_key in in_dict: + in_dict[od_key].add(p_name) + + # filter out variables that do not need to be checked (without inputs): + # inout_dict = {k: v for k, v in inout_dict.items() if k in in_dict} + + for key, inout_ps in inout_dict.items(): + in_ps = in_dict[key] + + verified_ios = [] + # now we only have to search and verify all inout variables + for io_p in inout_ps: + io_stack = [io_p] + while io_stack: + cur = io_stack[-1] + if cur in verified_ios: + io_stack.pop() + continue + + child_ios = self._deps_dict[cur].intersection(inout_ps - {cur}) + if child_ios: + if child_ios == set(verified_ios): + child_ins = in_ps.intersection(self._deps_dict[cur]) + # verify that all children have the previous io as + # dependency + problem_children = {} + for child_in in child_ins: + # we want to list all processes that should + # depend on the previous + # io-io + # / + # in + if not verified_ios[-1] in self._deps_dict[child_in]: + problem_children[child_in] = [ + p + for p in verified_ios + if p not in self._deps_dict[child_in] + ] + if problem_children: + raise RuntimeError( + f"While checking {key}, {cur} updates it" + f" and depends on some processes that use" + f" it, but they do not depend on {verified_ios[-1]}" + f" place them somewhere between or before " + f"their values: {problem_children}" + ) + 
# we can now safely remove these in nodes + in_ps -= child_ins + verified_ios.append(cur) + io_stack.pop() + elif child_ios - set(verified_ios): + io_stack.extend(child_ios) + else: + # the problem here is that + # io-..-io + # \ + # io + problem_ios = [ + p for p in verified_ios if p not in child_ios + ] + raise RuntimeError( + f"while checking {key}, order of inout process " + f"{cur} compared to {problem_ios} could not be " + f"established" + ) + else: + # we are at the bottom inout process: remove in + # variables from the set + # this can only happen if we are the first process at + # the bottom + if verified_ios: + # the problem here is + # io->..->io + # / + # io + problem_ios = [ + p for p in verified_ios if cur not in self._deps_dict[p] + ] + raise RuntimeError( + f"While checking {key}, inout process " + f"{verified_ios[-1]} has two branch dependencies." + f" Place {cur} before or somewhere between " + f"{verified_ios[:-1]}" + ) + in_ps -= self._deps_dict[cur] + verified_ios.append(cur) + io_stack.pop() + + # we finished all inout, and inputs that are descendants of inout + # vars, so all remaining input vars should depend on the last inout + # var + problem_ins = {} + for p in in_ps: + if not verified_ios[-1] in self._deps_dict[p]: + problem_ins[p] = [ + prob for prob in verified_ios if prob not in self._deps_dict[p] + ] + + if problem_ins: + raise RuntimeError( + f"while checking {key}, some input processes do not depend " + f"on {verified_ios[-1]}, with all inout processes {verified_ios}" + f" place them somewhere in between or before their values: {problem_ins}" + ) + def get_sorted_processes(self): self._sorted_processes = OrderedDict( [(p_name, self._processes_obj[p_name]) for p_name in self._sort_processes()] @@ -523,8 +669,9 @@ class Model(AttrMapping): This collection is ordered such that the computational flow is consistent with process inter-dependencies. 
- Ordering doesn't need to be explicitly provided ; it is dynamically - computed using the processes interfaces. + Ordering doesn't always need to be explicitly provided ; it is dynamically + computed using the processes interfaces. For other cases, custom + dependencies can be supplied. Processes interfaces are also used for automatically retrieving the model inputs, i.e., all the variables that require setting a @@ -534,17 +681,25 @@ class Model(AttrMapping): active = [] - def __init__(self, processes): + def __init__(self, processes, custom_dependencies={}, strict_order_check=False): """ Parameters ---------- processes : dict - Dictionnary with process names as keys and classes (decorated with + Dictionary with process names as keys and classes (decorated with :func:`process`) as values. + custom_dependencies : dict + Dictionary of custom dependencies. + keys are process names and values iterable of process names that it + depends on. + strict_order_check : bool + if True, aggressively check for correct ordering. (default: False) + For a variable with processes for which it is an inout variable, it + should look like: ``ins0->inout1->ins1->inout2->ins2`` Raises ------ - :exc:`NoteAProcessClassError` + :exc:`NotAProcessClassError` If values in ``processes`` are not classes decorated with :func:`process`.
@@ -572,9 +727,21 @@ def __init__(self, processes): self._processes_to_validate = builder.get_processes_to_validate() - self._dep_processes = builder.get_process_dependencies() + # clean custom dependencies + self._custom_dependencies = {} + for p_name, c_deps in custom_dependencies.items(): + c_deps = {c_deps} if isinstance(c_deps, str) else set(c_deps) + self._custom_dependencies[p_name] = c_deps + + self._dep_processes = builder.get_process_dependencies( + self._custom_dependencies + ) self._processes = builder.get_sorted_processes() + self._strict_order_check = strict_order_check + if self._strict_order_check: + builder._strict_order_check() + super(Model, self).__init__(self._processes) self._initialized = True @@ -1035,13 +1202,18 @@ def clone(self): Returns ------- cloned : Model - New Model instance with the same processes. + New Model instance with the same processes, custom dependencies and + checking behaviour. """ processes_cls = {k: type(obj) for k, obj in self._processes.items()} - return type(self)(processes_cls) + return type(self)( + processes_cls, self._custom_dependencies, self._strict_order_check + ) - def update_processes(self, processes): + def update_processes( + self, processes, custom_dependencies={}, strict_order_check=None + ): """Add or replace processe(s) in this model. Parameters @@ -1049,23 +1221,42 @@ def update_processes(self, processes): processes : dict Dictionnary with process names as keys and process classes as values. + custom_dependencies : dict + Dictionary with additional custom dependencies + strict_order_check : bool + Whether to strictly check for ordering in the new model, by default + retains the setting. Returns ------- updated : Model - New Model instance with updated processes. + New Model instance with updated processes, custom dependencies and + checking behaviour.
""" processes_cls = {k: type(obj) for k, obj in self._processes.items()} processes_cls.update(processes) - return type(self)(processes_cls) + + new_c_deps = {p: deps for p, deps in self._custom_dependencies.items()} + for p, deps in custom_dependencies.items(): + deps = {deps} if isinstance(deps, str) else set(deps) + if p in new_c_deps: + new_c_deps[p].update(deps) + else: + new_c_deps[p] = deps + + if strict_order_check is None: + strict_order_check = self._strict_order_check + + return type(self)(processes_cls, new_c_deps, strict_order_check) def drop_processes(self, keys): - """Drop processe(s) from this model. + """Drop processe(s) from this model. Also establishes new custom + dependencies if they would be lost. Parameters ---------- - keys : str or list of str + keys : str or iterable of str Name(s) of the processes to drop. Returns @@ -1074,13 +1265,52 @@ def drop_processes(self, keys): New Model instance with dropped processes. """ - if isinstance(keys, str): - keys = [keys] + keys = {keys} if isinstance(keys, str) else set(keys) processes_cls = { k: type(obj) for k, obj in self._processes.items() if k not in keys } - return type(self)(processes_cls) + + # we also should check for chains of deps e.g. + # a->b->c->d->e where {b,c,d} are removed + # then we have a->e left over. 
+ # perform a depth-first search on custom dependencies + # and let the custom deps propagate forward + completed = set() + for key in self._custom_dependencies: + if key in completed: + continue + key_stack = [key] + while key_stack: + cur = key_stack[-1] + if cur in completed: + key_stack.pop() + continue + + # if we have custom dependencies that are removed + # and are fully traversed, add their deps to the current + child_keys = keys.intersection(self._custom_dependencies[cur]) + if child_keys.issubset(completed): + # all children are added, so we are safe + self._custom_dependencies[cur].update( + *[ + self._custom_dependencies[child_key] + for child_key in child_keys + ] + ) + self._custom_dependencies[cur] -= child_keys + completed.add(cur) + key_stack.pop() + else: # if child_keys - completed: + # we need to search deeper: add to the stack. + key_stack.extend([k for k in child_keys - completed]) + + # now also remove keys from custom deps + for key in keys: + if key in self._custom_dependencies: + del self._custom_dependencies[key] + + return type(self)(processes_cls, self._custom_dependencies) def __eq__(self, other): if not isinstance(other, self.__class__): diff --git a/xsimlab/tests/test_model.py b/xsimlab/tests/test_model.py index 9228db5..103a596 100644 --- a/xsimlab/tests/test_model.py +++ b/xsimlab/tests/test_model.py @@ -148,6 +148,36 @@ def test_get_process_dependencies(self, model): # order of dependencies is not ensured assert set(actual[p_name]) == set(expected[p_name]) + def test_get_process_dependencies_custom(self, model): + @xs.process + class A: + pass + + @xs.process + class B: + pass + + @xs.process + class C: + pass + + actual = xs.Model( + {"a": A, "b": B}, custom_dependencies={"a": "b"} + ).dependent_processes + expected = {"a": ["b"], "b": []} + + for p_name in expected: + assert set(actual[p_name]) == set(expected[p_name]) + + # also test with a list + actual = xs.Model( + {"a": A, "b": B, "c": C}, custom_dependencies={"a": ["b", 
"c"]} + ).dependent_processes + expected = {"a": ["b", "c"], "b": [], "c": []} + + for p_name in expected: + assert set(actual[p_name]) == set(expected[p_name]) + @pytest.mark.parametrize( "p_name,dep_p_name", [ @@ -175,6 +205,135 @@ class Bar: with pytest.raises(RuntimeError, match=r"Cycle detected.*"): xs.Model({"foo": Foo, "bar": Bar}) + def test_strict_check(self): + # also give the variable different names + @xs.process + class Inout1: + v = xs.variable(intent="inout") + + @xs.process + class Inout2: + va = xs.foreign(Inout1, "v", intent="inout") + + @xs.process + class Inout3: + var = xs.foreign(Inout1, "v", intent="inout") + + @xs.process + class In1: + var = xs.foreign(Inout1, "v") + + # io # equivalent to # io-..-io + # in # # in #where any io can be in in deps + with pytest.raises(RuntimeError, match="some input processes do not"): + xs.Model({"io1": Inout1, "in1": In1}, strict_order_check=True) + # io-in #eq. to io-..-io-in + xs.Model( + {"io1": Inout1, "in1": In1}, + strict_order_check=True, + custom_dependencies={"in1": "io1"}, + ) + # in-io #eq to in-io-..-io + xs.Model( + {"io1": Inout1, "in1": In1}, + strict_order_check=True, + custom_dependencies={"io1": "in1"}, + ) + # io + # io + with pytest.raises(RuntimeError, match="has two branch dependencies"): + xs.Model({"io1": Inout1, "io2": Inout2}, strict_order_check=True) + # io-io + xs.Model( + {"io1": Inout1, "io2": Inout2}, + custom_dependencies={"io2": "io1"}, + strict_order_check=True, + ) + # io-io + # / + # in + with pytest.raises(RuntimeError, match="io2 updates it and depends"): + xs.Model( + {"io1": Inout1, "io2": Inout2, "in1": In1}, + custom_dependencies={"io2": ["io1", "in1"]}, + strict_order_check=True, + ) + # io io + # \ / + # in + xs.Model( + {"io1": Inout1, "io2": Inout2, "in1": In1}, + custom_dependencies={"io2": "in1", "in1": "io1"}, + ) + + # the following is a bit arbitrary which raises: 2 or 3 based on the ordering of dicts + # io2-io1 + # / + # io3 This raises in first tree 
traversal: should be equivalent to + with pytest.raises(RuntimeError, match="has two branch dependencies"): + xs.Model( + {"io1": Inout1, "io2": Inout2, "io3": Inout3}, + custom_dependencies={"io1": ["io2", "io3"]}, + strict_order_check=True, + ) + + # io-io + # \ + # io + with pytest.raises(RuntimeError, match="order of inout process"): + xs.Model( + {"io1": Inout1, "io2": Inout2, "io3": Inout3}, + custom_dependencies={"io1": "io2", "io3": "io2"}, + strict_order_check=True, + ) + + def test_strict_check_multiple_vars_in_process(self): + # in|->|io|->|in|->|io| - foo variable + # | |in|->|io|->|in| - bar variable + @xs.process + class Out: + foo = xs.on_demand() + bar = xs.variable(intent="out") + + @foo.compute + def method(self): + pass + + @xs.process + class FooInBarInout: + foo = xs.foreign(Out, "foo") + bar = xs.foreign(Out, "bar", intent="inout") + + @xs.process + class FooInoutBarIn: + foo = xs.foreign(Out, "foo") + bar = xs.foreign(Out, "bar", intent="inout") + + @xs.process + class FooIn: + foo = xs.foreign(Out, "foo") + + @xs.process + class BarInFooInout: + bar = xs.foreign(Out, "bar") + foo = xs.foreign(Out, "foo", intent="inout") + + xs.Model( + { + "out": Out, + "foo_in_bar_inout": FooInBarInout, + "foo_inout_bar_in": FooInoutBarIn, + "foo_in": FooIn, + "boo_in_foo_inout": BarInFooInout, + }, + custom_dependencies={ + "boo_in_foo_inout": "foo_in_bar_inout", + "foo_in_bar_inout": "foo_inout_bar_in", + "foo_inout_bar_in": "foo_in", + }, + strict_order_check=True, + ) + def test_process_inheritance(self, model): @xs.process class InheritedProfile(Profile): @@ -289,11 +448,63 @@ def test_update_processes(self, no_init_model, model): ) assert m == model + def test_update_processes_strict_check(self): + m = xs.Model({}, strict_order_check=True) + assert m.update_processes({})._strict_order_check == True + assert ( + m.update_processes({}, strict_order_check=False)._strict_order_check + == False + ) + + def test_update_processes_custom(self): + @xs.process 
+ class A: + pass + + @xs.process + class B: + pass + + @xs.process + class C: + pass + + ma = xs.Model({"a": A}) + mab = xs.Model({"a": A, "b": B}, custom_dependencies={"b": "a"}) + mabc = xs.Model({"a": A, "b": B, "c": C}, custom_dependencies={"b": {"a", "c"}}) + + assert ma.update_processes({"b": B}, custom_dependencies={"b": "a"}) == mab + assert mab.update_processes({"c": C}, custom_dependencies={"b": "c"}) == mabc + @pytest.mark.parametrize("p_names", ["add", ["add"]]) def test_drop_processes(self, no_init_model, simple_model, p_names): m = no_init_model.drop_processes(p_names) assert m == simple_model + def test_drop_processes_custom(self): + @xs.process + class A: + pass + + @xs.process + class B: + pass + + @xs.process + class C: + pass + + @xs.process + class D: + pass + + model = xs.Model( + {"a": A, "b": B, "c": C, "d": D}, + custom_dependencies={"d": "c", "c": "b", "b": "a"}, + ) + model = model.drop_processes(["b", "c"]) + assert model.dependent_processes["d"] == ["a"] + def test_visualize(self, model): pytest.importorskip("graphviz") ipydisp = pytest.importorskip("IPython.display") diff --git a/xsimlab/tests/test_variable.py b/xsimlab/tests/test_variable.py index 159edf5..9efb1c9 100644 --- a/xsimlab/tests/test_variable.py +++ b/xsimlab/tests/test_variable.py @@ -104,9 +104,6 @@ class Foo: def test_foreign(): - with pytest.raises(ValueError, match="intent='inout' is not supported.*"): - xs.foreign(ExampleProcess, "some_var", intent="inout") - var = attr.fields(ExampleProcess).out_foreign_var ref_var = attr.fields(AnotherProcess).another_var diff --git a/xsimlab/variable.py b/xsimlab/variable.py index d6b642a..e246cf4 100644 --- a/xsimlab/variable.py +++ b/xsimlab/variable.py @@ -445,9 +445,6 @@ def foreign(other_process_cls, var_name, intent="in"): model. """ - if intent == "inout": - raise ValueError("intent='inout' is not supported for foreign variables") - ref_var = attr.fields_dict(other_process_cls)[var_name] metadata = {