Skip to content

Commit

Permalink
Fix a bug in variables.Combine.var_context. Add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
ynikitenko committed Sep 8, 2020
1 parent af103c2 commit f0168fb
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
18 changes: 9 additions & 9 deletions lena/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,11 @@ def __init__(self, name, getter, **kwargs):
raise lena.core.LenaTypeError(
"a callable getter must be provided, {} given".format(getter)
)
# getter is public for possible performance implementations
# (without context)
self.getter = getter

# var_context is public, so that one can get all attributes
self.var_context = {
"name": self.name,
}
Expand All @@ -131,7 +134,6 @@ def __call__(self, value):
# context = copy.deepcopy(context)
var_context = context.get("variable")
if var_context:
# todo: check that several compose context work
# deep copy, otherwise it will be updated during update_recursively
context["variable"]["compose"] = copy.deepcopy(var_context)
# update recursively, because we need to preserve "type"
Expand All @@ -150,12 +152,11 @@ def __getattr__(self, name):
if name.startswith('_'):
raise AttributeError(name)
try:
attr = self.var_context[name]
return self.var_context[name]
except KeyError:
raise lena.core.LenaAttributeError(
"{} missing in {}".format(name, self.name)
)
return attr

def get(self, key, default=None):
"""Return the attribute *key* if present, else default.
Expand Down Expand Up @@ -213,15 +214,14 @@ def __init__(self, *args, **kwargs):
name = kwargs.pop("name", None)
if name is None:
name = "_".join([var.name for var in self._vars])
var_context = {
"name": name,
"dim": self.dim,
"getter": getter,
}
var_context = {}
var_context.update(kwargs)
assert "dim" not in kwargs # to set it manually is meaningless
var_context.update({"dim": self.dim})
var_context["combine"] = tuple(
copy.deepcopy(var.var_context) for var in self._vars
)
super(Combine, self).__init__(**var_context)
super(Combine, self).__init__(name=name, getter=getter, **var_context)

def __getitem__(self, index):
"""Get variable at the given *index*."""
Expand Down
27 changes: 27 additions & 0 deletions tests/variables/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,30 @@ def test_variable():
}
}
)


def test_getattr_and_var_context():
x_mm = Variable("x", unit="mm", getter=lambda x: x*10, type="coordinate")
y_mm = Variable("y", unit="mm", getter=lambda x: x*10, type="coordinate")

# Variable attribute works
assert x_mm.type == "coordinate"

# getter should not be in var_context
assert "getter" not in x_mm.var_context

# Combine attribute works
combine1 = Combine(x_mm, name="xy")
assert combine1.name == "xy"

xy_range = [(0,1), (0,1)]
combine2 = Combine(x_mm, y_mm, name="xy", range=xy_range)
assert combine2.range == xy_range

# name and getter should not be in var_context
assert "getter" not in combine1.var_context

# Compose attribute works
compose = Compose(x_mm, y_mm, name="xy", type="coordinate")
assert compose.type == "coordinate"
assert "getter" not in compose.var_context

0 comments on commit f0168fb

Please sign in to comment.