Skip to content

Commit

Permalink
Ensure no variable intent conflict when building a model (#57)
Browse files Browse the repository at this point in the history
* ensure no variable intent conflict when building a model

* fix missing import

* update what's new
  • Loading branch information
benbovy committed Sep 23, 2019
1 parent 86729d4 commit d08c334
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
7 changes: 7 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ Release Notes
v0.3.0 (Unreleased)
-------------------

Enhancements
~~~~~~~~~~~~

- Ensure that there is no ``intent`` conflict between the variables
declared in a model. This check is explicit at Model creation and a
more meaningful error message is shown when it fails (:issue:`57`).


v0.2.1 (7 November 2018)
------------------------
Expand Down
30 changes: 30 additions & 0 deletions xsimlab/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,34 @@ def set_process_keys(self):
if od_key is not None:
p_obj.__xsimlab_od_keys__[var.name] = od_key

def ensure_no_intent_conflict(self):
"""Raise an error if more than one variable with
intent='out' targets the same variable.
"""
filter_out = lambda var: (
var.metadata['intent'] == VarIntent.OUT and
var.metadata['var_type'] != VarType.ON_DEMAND
)

targets = defaultdict(list)

for p_name, p_obj in self._processes_obj.items():
for var in filter_variables(p_obj, func=filter_out).values():
target_key = p_obj.__xsimlab_store_keys__.get(var.name)
targets[target_key].append((p_name, var.name))

conflicts = {k: v for k, v in targets.items() if len(v) > 1}

if conflicts:
conflicts_str = {k: ' and '.join(["'{}.{}'".format(*i) for i in v])
for k, v in conflicts.items()}
msg = '\n'.join(["'{}.{}' set by: {}".format(*k, v)
for k, v in conflicts_str.items()])

raise ValueError(
"Conflict(s) found in given variable intents:\n" + msg)

def get_all_variables(self):
"""Get all variables in the model as a list of
``(process_name, var_name)`` tuples.
Expand Down Expand Up @@ -364,6 +392,8 @@ def __init__(self, processes):
self._all_vars = builder.get_all_variables()
self._all_vars_dict = None

builder.ensure_no_intent_conflict()

self._input_vars = builder.get_input_variables()
self._input_vars_dict = None

Expand Down
11 changes: 10 additions & 1 deletion xsimlab/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

import xsimlab as xs
from xsimlab.tests.fixture_model import AddOnDemand, InitProfile
from xsimlab.tests.fixture_model import AddOnDemand, InitProfile, Profile


class TestModelBuilder(object):
Expand Down Expand Up @@ -52,6 +52,15 @@ def test_get_all_variables(self, model):
assert all([p_name in model for p_name, _ in model.all_vars])
assert ('profile', 'u') in model.all_vars

def test_ensure_no_intent_conflict(self, model):
@xs.process
class Foo(object):
u = xs.foreign(Profile, 'u', intent='out')

with pytest.raises(ValueError) as excinfo:
invalid_model = model.update_processes({'foo': Foo})
assert "Conflict(s)" in str(excinfo.value)

def test_get_input_variables(self, model):
expected = {('init_profile', 'n_points'),
('roll', 'shift'),
Expand Down

0 comments on commit d08c334

Please sign in to comment.