Skip to content

Commit

Permalink
Fix misinterpreted tuples passed as allowed_dims for Variable (#17)
Browse files Browse the repository at this point in the history
* add regression test (>1-length tuple passed to allowed_dims)

* fix >1-length tuples passed to Variable.allowed_dims

* fix broken tests
  • Loading branch information
benbovy committed Nov 7, 2017
1 parent 8e4ea39 commit c46e3fc
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion xsimlab/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ExampleProcess(Process):
"""A full example of process interface.
"""
var = Variable((), provided=True)
var_list = VariableList([Variable('x'), Variable(((), 'x'))])
var_list = VariableList([Variable('x'), Variable([(), 'x'])])
var_group = VariableGroup('group')
no_var = 'this is not a variable object'

Expand Down
9 changes: 6 additions & 3 deletions xsimlab/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ def test_constructor(self):
var = Variable(allowed_dims)
assert var.allowed_dims == ((),)

for allowed_dims in ('x', ['x'], tuple(['x'])):
for allowed_dims in ('x', ['x'], ('x')):
var = Variable(allowed_dims)
assert var.allowed_dims == (('x',),)

var = Variable(('x', 'y'))
assert var.allowed_dims == (('x', 'y'),)

var = Variable([(), 'x', ('x', 'y')])
assert var.allowed_dims, ((), ('x',), ('x', 'y'))
assert var.allowed_dims == ((), ('x',), ('x', 'y'))

def test_validators(self):
# verify default validators + user supplied validators
Expand Down Expand Up @@ -74,7 +77,7 @@ def test_to_xarray_variable(self):
xr.testing.assert_identical(xr_var, expected_xr_var)

def test_repr(self):
var = Variable(((), 'x', ('x', 'y')))
var = Variable([(), 'x', ('x', 'y')])
expected_repr = "<xsimlab.Variable (), ('x'), ('x', 'y')>"
assert repr(var) == expected_repr

Expand Down
4 changes: 2 additions & 2 deletions xsimlab/variable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ def __init__(self, allowed_dims, provided=False, optional=False,

if not len(allowed_dims):
allowed_dims = [()]
if isinstance(allowed_dims, str):
elif isinstance(allowed_dims, str):
allowed_dims = [(allowed_dims,)]
elif isinstance(allowed_dims, list):
allowed_dims = [tuple([d]) if isinstance(d, str) else tuple(d)
for d in allowed_dims]
elif len(allowed_dims) == 1:
else:
allowed_dims = [allowed_dims]
self.allowed_dims = tuple(allowed_dims)

Expand Down

0 comments on commit c46e3fc

Please sign in to comment.