Skip to content

Commit

Permalink
wip tests (model)
Browse files Browse the repository at this point in the history
  • Loading branch information
benbovy committed Jul 10, 2017
1 parent 6f331ee commit 4cfc9e8
Showing 1 changed file with 106 additions and 34 deletions.
140 changes: 106 additions & 34 deletions xsimlab/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from textwrap import dedent

import numpy as np
from numpy.testing import assert_array_equal
Expand All @@ -10,7 +11,7 @@


class Grid(Process):
x_size = Variable((), optional=True)
x_size = Variable((), optional=True, description='grid size')
x = Variable('x', provided=True)

class Meta:
Expand All @@ -25,7 +26,7 @@ def initialize(self):


class Quantity(Process):
quantity = Variable('x')
quantity = Variable('x', description='a quantity')
all_effects = VariableGroup('effect')

def run_step(self, *args):
Expand All @@ -36,46 +37,58 @@ def finalize_step(self):


class SomeProcess(Process):
some_param = Variable(())
some_param = Variable((), description='some parameter')
copy_param = Variable((), provided=True)
x = ForeignVariable(Grid, 'x')
quantity = ForeignVariable(Quantity, 'quantity')
some_effect = Variable('x', group='effect', provided=True)

def initialize(self):
self.copy_param.value = self.some_param.value

def run_step(self, dt):
self.some_effect.value = self.x.value * self.some_param.value + dt

def finalize(self):
self.some_effect.rate = 0


class OtherProcess(Process):
other_param = Variable(())
x = ForeignVariable(Grid, 'x')
copy_param = ForeignVariable(SomeProcess, 'copy_param')
quantity = ForeignVariable(Quantity, 'quantity')
other_effect = Variable('x', group='effect', provided=True)

def run_step(self, dt):
self.other_effect.value = self.x.value * self.other_param.value - dt
self.other_effect.value = self.x.value * self.copy_param.value - dt

@diagnostic
def x2(self):
return self.x * 2


class MetaProcess(Process):
class PlugProcess(Process):
meta_param = Variable(())
some_param = ForeignVariable(SomeProcess, 'some_param')
other_param = ForeignVariable(OtherProcess, 'other_param')
some_param = ForeignVariable(SomeProcess, 'some_param', provided=True)

def run_step(self, *args):
self.some_param.value = self.param.value
self.other_param.value = self.param.value
self.some_param.value = self.meta_param.value


class TestModel(unittest.TestCase):
def get_test_model():
model = Model({'grid': Grid,
'some_process': SomeProcess,
'other_process': OtherProcess,
'quantity': Quantity})

def setUp(self):
self.model = Model({'grid': Grid,
'some_process': SomeProcess,
'other_process': OtherProcess,
'quantity': Quantity})
model.grid.x_size.value = 10
model.quantity.quantity.state = np.zeros(10)
model.some_process.some_param.value = 1

return model


class TestModel(unittest.TestCase):

def test_constructor(self):
# test invalid processes
Expand All @@ -92,43 +105,102 @@ class OtherClass(object):
Model({'invalid_class': OtherClass})

# test process ordering
sorted_process_names = list(self.model.keys())
self.assertEqual(sorted_process_names[0], 'grid')
self.assertEqual(sorted_process_names[-1], 'quantity')
self.assertIn('some_process', sorted_process_names[1:-1])
self.assertIn('other_process', sorted_process_names[1:-1])
model = get_test_model()
expected = ['grid', 'some_process', 'other_process', 'quantity']
self.assertEqual(list(model), expected)

# test dict-like vs. attribute access
self.assertIs(self.model['grid'], self.model.grid)
self.assertIs(model['grid'], model.grid)

def test_input_vars(self):
model = get_test_model()
expected = {'grid': ['x_size'],
'some_process': ['some_param'],
'other_process': ['other_param'],
'quantity': ['quantity']}
actual = {k: list(v.keys()) for k, v in self.model.input_vars.items()}
actual = {k: list(v.keys()) for k, v in model.input_vars.items()}
self.assertDictEqual(expected, actual)

def test_is_input(self):
self.assertTrue(self.model.is_input(self.model.grid.x_size))
self.assertTrue(self.model.is_input(('grid', 'x_size')))
self.assertFalse(self.model.is_input(('quantity', 'all_effects')))
model = get_test_model()
self.assertTrue(model.is_input(model.grid.x_size))
self.assertTrue(model.is_input(('grid', 'x_size')))
self.assertFalse(model.is_input(('quantity', 'all_effects')))

external_variable = Variable(())
self.assertFalse(model.is_input(external_variable))

def test_initialize(self):
model = self.model.clone()
model.grid.x_size.value = 10
model = get_test_model()
model.initialize()
expected = np.arange(10)
assert_array_equal(model.grid.x.value, expected)

def test_run_step(self):
model = self.model.clone()
model.grid.x_size.value = 10
model.some_process.some_param.value = 1
model.other_process.other_param.value = 1

model = get_test_model()
model.initialize()
model.run_step(100)

expected = model.grid.x.value * 2
assert_array_equal(model.quantity.quantity.change, expected)

def test_finalize_step(self):
model = get_test_model()
model.initialize()
model.run_step(100)
model.finalize_step()

expected = model.grid.x.value * 2
assert_array_equal(model.quantity.quantity.state, expected)

def test_finalize(self):
model = get_test_model()
model.finalize()
self.assertEqual(model.some_process.some_effect.rate, 0)

def test_clone(self):
model = get_test_model()
cloned = model.clone()

for (ck, cp), (k, p) in zip(cloned.items(), model.items()):
self.assertEqual(ck, k)
self.assertIsNot(cp, p)

def test_update_processes(self):
model = get_test_model()
expected = Model({'grid': Grid,
'plug_process': PlugProcess,
'some_process': SomeProcess,
'other_process': OtherProcess,
'quantity': Quantity})
actual = model.update_processes({'plug_process': PlugProcess})

# TODO: more advanced (public?) test function to compare two models?
self.assertEqual(list(actual), list(expected))

def test_drop_processes(self):
model = get_test_model()

expected = Model({'grid': Grid,
'some_process': SomeProcess,
'quantity': Quantity})
actual = model.drop_processes('other_process')
self.assertEqual(list(actual), list(expected))

expected = Model({'grid': Grid,
'quantity': Quantity})
actual = model.drop_processes(['some_process', 'other_process'])
self.assertEqual(list(actual), list(expected))

def test_repr(self):
model = get_test_model()
expected = dedent("""\
<xsimlab.Model (4 processes, 3 inputs)>
grid
x_size (in) grid size
some_process
some_param (in) some parameter
other_process
quantity
quantity (in) a quantity""")

self.assertEqual(repr(model), expected)

0 comments on commit 4cfc9e8

Please sign in to comment.