Skip to content

Commit

Permalink
variables.Combine now creates a range attribute if all its variables …
Browse files Browse the repository at this point in the history
…have range.
  • Loading branch information
ynikitenko committed Oct 31, 2020
1 parent dcd2568 commit ef49a75
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
11 changes: 11 additions & 0 deletions lena/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,14 @@ def __init__(self, *args, **kwargs):
*dim* is the number of variables.
*range*. If all variables have an attribute *range*,
the *range* of this variable is set to a list of them.
All *args* must be *Variables*
and there must be at least one of them,
otherwise :class:`LenaTypeError` is raised.
"""
# set _vars, dim and getter.
if not args:
raise lena.core.LenaTypeError(
"Combine must be initialized with 1 or more variables"
Expand All @@ -211,6 +215,7 @@ def __init__(self, *args, **kwargs):
self.dim = len(args)
getter = lambda val: tuple(var.getter(val) for var in self._vars)

# update var_context with name and kwargs.
name = kwargs.pop("name", None)
if name is None:
name = "_".join([var.name for var in self._vars])
Expand All @@ -221,6 +226,12 @@ def __init__(self, *args, **kwargs):
var_context["combine"] = tuple(
copy.deepcopy(var.var_context) for var in self._vars
)

# set range of the combined variables
if all(hasattr(var, "range") for var in self._vars):
range_ = [var.range for var in self._vars]
var_context["range"] = range_

super(Combine, self).__init__(name=name, getter=getter, **var_context)

def __getitem__(self, index):
Expand Down
16 changes: 15 additions & 1 deletion tests/variables/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,28 @@ def test_combine():
Combine("xy")
with pytest.raises(LenaTypeError):
Combine(lambda x: x)
mm = Variable("mm", unit="mm", getter=lambda x: x*10, type="coordinate")
mm = Variable("mm", unit="mm", getter=lambda x: x*10, type="coordinate", range=[0, 100])
c = Combine(mm, name="xy")
assert c[0] == mm

# range creation works
assert c.range == [mm.range]
cm = Variable("cm", unit="cm", getter=lambda x: x, type="coordinate", range=[0, 10])
c2 = Combine(mm, cm)
assert c2.range == [mm.range, cm.range]
# same explicitly
assert c2.range == [[0, 100], [0, 10]]
# has no range
eV = Variable("cm", unit="cm", getter=lambda x: x, type="coordinate")
c3 = Combine(mm, eV)
with pytest.raises(lena.core.LenaAttributeError):
c3.range

## __call__ works
data = [1, 2, 3]
results = map(c, data)
assert [res[0] for res in results] == [(10,), (20,), (30,)]
mm = Variable("mm", unit="mm", getter=lambda x: x*10, type="coordinate")
m = Variable("m", unit="m", getter=lambda x: x/100., type="coordinate")
c = Combine(mm, m)
results = map(c, data)
Expand Down

0 comments on commit ef49a75

Please sign in to comment.