Skip to content

Commit

Permalink
wip tests (variable, utils)
Browse files Browse the repository at this point in the history
  • Loading branch information
benbovy committed Jun 23, 2017
1 parent 6465dd1 commit 3fbb0ed
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 2 deletions.
11 changes: 11 additions & 0 deletions xsimlab/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import unittest

from xsimlab import utils


class TestImportRequired(unittest.TestCase):

def test(self):
err_msg = "no module"
with self.assertRaisesRegex(RuntimeError, err_msg):
utils.import_required('this_module_doesnt_exits', err_msg)
94 changes: 92 additions & 2 deletions xsimlab/tests/test_variable.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import unittest
from collections import OrderedDict

import xarray as xr

from xsimlab.variable.base import (Variable, ForeignVariable, diagnostic,
DiagnosticVariable, ValidationError)
DiagnosticVariable, VariableList,
VariableGroup, ValidationError)
from xsimlab.variable.custom import (NumberVariable, FloatVariable,
IntegerVariable)
from xsimlab.process import Process


class MyProcess(Process):
var = Variable(())
var = Variable((), group='mygroup')

@diagnostic
def diag(self):
Expand All @@ -21,6 +25,14 @@ def diag2(self):
return 2


class MyProcess2(Process):
var = Variable((), group='mygroup')


class MyProcess3(Process):
var = VariableGroup('mygroup')


class TestVariable(unittest.TestCase):

def test_constructor(self):
Expand Down Expand Up @@ -148,3 +160,81 @@ def test_repr(self):
expected_repr = "<xsimlab.DiagnosticVariable>"
self.assertEqual(repr(self.process.diag), expected_repr)
self.assertEqual(repr(self.process.diag2), expected_repr)


class TestVariableList(unittest.TestCase):

def test_constructor(self):
with self.assertRaisesRegex(ValueError, "found variables mixed"):
_ = VariableList([2, Variable(())])


class TestVariableGroup(unittest.TestCase):

def test_iter(self):
myprocess = MyProcess()
myprocess2 = MyProcess2()
myprocess3 = MyProcess3()

with self.assertRaisesRegex(ValueError, "cannot retrieve variables"):
_ = list(myprocess3.var)

processes_dict = OrderedDict((('p1', myprocess), ('p2', myprocess2)))
myprocess3.var._set_variables(processes_dict)

expected = [myprocess.var, myprocess2.var]
for var, proc in zip(myprocess3.var, processes_dict.values()):
var._other_process_obj = proc

fvar_list = [var.ref_var for var in myprocess3.var]
self.assertEqual(fvar_list, expected)

def test_repr(self):
myprocess3 = MyProcess3()

expected_repr = "<xsimlab.VariableGroup 'mygroup'>"
self.assertEqual(repr(myprocess3.var), expected_repr)


class TestNumberVariable(unittest.TestCase):

def test_validate(self):
var = NumberVariable((), bounds=(0, 1))
for data in (-1, [-1, 0], [-1, 1], [0, 2], 2):
xr_var = var.to_xarray_variable(data)
with self.assertRaisesRegex(ValidationError, "out of bounds"):
var.validate(xr_var)

for ib in [(True, False), (False, True), (False, False)]:
var = NumberVariable((), bounds=(0, 1), inclusive_bounds=ib)
xr_var = var.to_xarray_variable([0, 1])
with self.assertRaisesRegex(ValidationError, "out of bounds"):
var.validate(xr_var)


class TestFloatVariable(unittest.TestCase):

def test_validators(self):
var = FloatVariable(())

for val in [1, 1.]:
xr_var = xr.Variable((), val)
var.run_validators(xr_var)

xr_var = xr.Variable((), '1')
with self.assertRaisesRegex(ValidationError, "invalid dtype"):
var.run_validators(xr_var)


class TestIntegerVariable(unittest.TestCase):

def test_validators(self):
var = IntegerVariable(())

xr_var = xr.Variable((), 1)
var.run_validators(xr_var)

for val in [1., '1']:
xr_var = xr.Variable((), val)
with self.assertRaisesRegex(ValidationError, "invalid dtype"):
var.run_validators(xr_var)

0 comments on commit 3fbb0ed

Please sign in to comment.