Skip to content

Commit

Permalink
Fix a bug in Variable._update_context when type was missing in a comp…
Browse files Browse the repository at this point in the history
…osing (further in the sequence) variable.
  • Loading branch information
ynikitenko committed Aug 24, 2023
1 parent c0b1aa5 commit 9175baa
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
18 changes: 15 additions & 3 deletions lena/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,13 @@ def _update_context(context, var_context):
cvar = context.get("variable")
# preserve variable composition information if that is present
composed = []
if cvar and ("type" in var_context) and ("type" in cvar):
if cvar and ("type" in cvar):
# If cvar has no "type",
# then no types were in the recent variable or earlier
cur_type = var_context["type"]
if "type" in var_context:
cur_type = var_context["type"]
else:
cur_type = []
if "compose" in cvar:
assert isinstance(cvar["compose"], list)
else:
Expand All @@ -200,7 +203,8 @@ def _update_context(context, var_context):
assert isinstance(var_context["compose"], list)
cvar["compose"].extend(cur_type)
else:
cvar["compose"].append(cur_type)
if cur_type:
cvar["compose"].append(cur_type)
composed = cvar["compose"]

old_cvar = context.get("variable", {})
Expand Down Expand Up @@ -276,7 +280,15 @@ def __init__(self, *args, **kwargs):
name = kwargs.pop("name", None)
if name is None:
name = "_".join([var.name for var in self._vars])

var_context = {}
# we don't preserve types of combined variables,
# because they will take too much space in context (visually).
# types = [var.var_context.get("type", None) for var in args]
# type1 = types[0]
# # preserve type it is same for all variables
# if all((tp == type1 for tp in types[1:])):
# var_context["type"] = type1
var_context.update(kwargs)
assert "dim" not in kwargs # to set it manually is meaningless
var_context.update({"dim": self.dim})
Expand Down
32 changes: 32 additions & 0 deletions tests/variables/test_variable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import pytest
from copy import deepcopy

import lena.core
import lena.context
Expand All @@ -8,6 +9,10 @@
from lena.variables.variable import Combine, Compose, Variable


# "double events"
dev_data = [((1, 2.2, 3), (1.5, 2., 3))]


def test_combine():
## init and getitem work
# must have arguments
Expand Down Expand Up @@ -45,6 +50,7 @@ def test_combine():
results = map(c, data)
assert c.name == "mm_m"
assert [res[0] for res in results] == [(10,0.01), (20,0.02), (30,0.03)]

# combination of Combines works
c = Combine(mm, Combine(mm, m))
results = [c(dt) for dt in data]
Expand Down Expand Up @@ -104,6 +110,32 @@ def test_combine():
}
}

# composition of variables works with Combine
positron = Variable("positron", getter=lambda dev: dev[0], type="particle")
x = Variable("x", getter=lambda coord: coord[0], type="coordinate")
y = Variable("y", getter=lambda coord: coord[1], type="coordinate")
xy = Combine(x, y)
seq = Sequence(positron, xy)
res = list(seq.run(deepcopy(dev_data)))
assert len(res) == 1
assert res[0][0] == (1, 2.2)
assert res[0][1] == {
'variable': {
'combine': (
{
'coordinate': {'name': 'x'}, 'name': 'x', 'type': 'coordinate'
},
{
'coordinate': {'name': 'y'}, 'name': 'y', 'type': 'coordinate'
}
),
'compose': ['particle'],
'dim': 2,
'name': 'x_y',
'particle': {'name': 'positron'},
}
}


def test_compose():
data = [((1.05, 0.98, 0.8), (1.1, 1.1, 1.3))]
Expand Down

0 comments on commit 9175baa

Please sign in to comment.