Skip to content

Commit

Permalink
Merge pull request #1133 from pyiron/error
Browse files Browse the repository at this point in the history
Add numpy equality check for unit tests
  • Loading branch information
niklassiemer committed Jun 15, 2023
2 parents 721f3d6 + 021f1d2 commit b6bf320
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 47 deletions.
18 changes: 17 additions & 1 deletion pyiron_base/_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pyiron_base.project.generic import Project
from abc import ABC
from inspect import getfile
import numpy as np


__author__ = "Liam Huber"
Expand All @@ -30,9 +31,24 @@
class PyironTestCase(unittest.TestCase, ABC):

"""
Tests that also include testing the docstrings in the specified modules
Base class for all pyiron unit tets.
Registers utility type equality functions:
- np.testing.assert_array_equal
Optionally includes testing the docstrings in the specified module by
overloading :attr:`~.docstring_module`.
"""

def setUp(self):
self.addTypeEqualityFunc(np.ndarray, self._assert_equal_numpy)

def _assert_equal_numpy(self, a, b, msg=None):
try:
np.testing.assert_array_equal(a, b, err_msg=msg)
except AssertionError as e:
raise self.failureException(*e.args) from None

@classmethod
def setUpClass(cls):
cls._initial_settings_configuration = state.settings.configuration.copy()
Expand Down
1 change: 1 addition & 0 deletions tests/database/test_filetable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TestFileTable(PyironTestCase):
# touching them right now. -Liam Huber

def setUp(self) -> None:
super().setUp()
here = dirname(abspath(__file__))
self.loc1 = join(here, "ft_test_loc1")
self.loc2 = join(here, "ft_test_loc2")
Expand Down
1 change: 1 addition & 0 deletions tests/generic/test_datacontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def setUpClass(cls):
cls.pl["tail"] = DataContainer([2, 4, 8])

def setUp(self):
super().setUp()
self.hdf = self.project.create_hdf(self.project.path, "test")

def tearDown(self):
Expand Down
30 changes: 12 additions & 18 deletions tests/generic/test_fileHDFio.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,29 @@ def _write_full_hdf_content(hdf):
def _check_full_hdf_values(self, hdf, group="content"):
with self.subTest(group + "/array"):
array = hdf[group + "/array"]
self.assertTrue(np.array_equal(array, np.array([1, 2, 3, 4, 5, 6])))
self.assertEqual(array, np.array([1, 2, 3, 4, 5, 6]))
self.assertIsInstance(array, np.ndarray)
self.assertEqual(array.dtype, np.dtype(int))

with self.subTest(group + "/array_3d"):
array = hdf[group]["array_3d"]
self.assertTrue(
np.array_equal(
self.assertEqual(
array,
np.array([[1, 2, 3], [4, 5, 6]]),
)
)
self.assertIsInstance(array, np.ndarray)
self.assertEqual(array.dtype, np.dtype(int))

with self.subTest(group + "/indices"):
array = hdf[group + "/indices"]
self.assertTrue(np.array_equal(array, np.array([1, 1, 1, 1, 6])))
self.assertEqual(array, np.array([1, 1, 1, 1, 6]))
self.assertIsInstance(array, np.ndarray)
self.assertEqual(array.dtype, np.dtype(int))

with self.subTest(group + "/traj"):
array = hdf[group + "/traj"]
self.assertTrue(np.array_equal(array[0], np.array([[1, 2, 3], [4, 5, 6]])))
self.assertTrue(np.array_equal(array[1], np.array([[7, 8, 9]])))
self.assertEqual(array[0], np.array([[1, 2, 3], [4, 5, 6]]))
self.assertEqual(array[1], np.array([[7, 8, 9]]))
self.assertIsInstance(array, np.ndarray)
self.assertEqual(array.dtype, np.dtype(object))

Expand All @@ -68,11 +66,9 @@ def _check_full_hdf_values(self, hdf, group="content"):
with self.subTest(group + "/dict_numpy"):
content_dict = hdf[group + "/dict_numpy"]
self.assertEqual(content_dict["key_1"], 1)
self.assertTrue(
np.array_equal(
content_dict["key_2"],
np.array([1, 2, 3, 4, 5, 6]),
)
self.assertEqual(
content_dict["key_2"],
np.array([1, 2, 3, 4, 5, 6]),
)

with self.subTest(group + "/group/some_entry"):
Expand Down Expand Up @@ -159,12 +155,12 @@ def test__convert_dtype_obj_array(self):

with self.assertLogs(state.logger):
array = self.i_o_hdf5._convert_dtype_obj_array(int_array_as_objects_array)
self.assertTrue(np.array_equal(array, int_array_as_objects_array))
self.assertEqual(array, int_array_as_objects_array)
self.assertEqual(array.dtype, np.dtype(int))

with self.assertLogs(state.logger):
array = self.i_o_hdf5._convert_dtype_obj_array(float_array_as_objects_array)
self.assertTrue(np.array_equal(array, float_array_as_objects_array))
self.assertEqual(array, float_array_as_objects_array)
self.assertEqual(array.dtype, np.dtype(float))

def test_array_type_conversion(self):
Expand Down Expand Up @@ -315,12 +311,10 @@ def test_get(self):
42,
"default value not returned when value doesn't exist.",
)
self.assertTrue(
np.array_equal(
self.assertEqual(
self.full_hdf5.get("content/array", default=42),
np.array([1, 2, 3, 4, 5, 6]),
),
"default value returned when value does exist.",
"default value returned when value does exist.",
)
with self.assertRaises(ValueError):
self.empty_hdf5.get("non_existing_key")
Expand Down
51 changes: 26 additions & 25 deletions tests/generic/test_flattenedstorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ def test_add_chunk_add_array(self):
except ValueError:
# both checks below are regression tests for https://github.com/pyiron/pyiron_contrib/pull/197
self.fail("add_chunk should not raise an exception when passed a value for an existing per-chunk array.")
self.assertTrue(np.array_equal(val, cont.get_array("perchunk", 0)),
self.assertEqual(val, cont.get_array("perchunk", 0),
"add_chunk did not remove first axis on a per chunk array!")
# test the same, but now let the array be created by add_chunk, instead of doing it on our own
cont.add_chunk(2, perelem=[1,1], perchunk2=val[np.newaxis, :])
self.assertTrue(np.array_equal(val, cont.get_array("perchunk2", 1)),
self.assertEqual(val, cont.get_array("perchunk2", 1),
"add_chunk did not remove first axis on a per chunk array!")


Expand All @@ -173,22 +173,22 @@ def test_get_array(self):
for n, e, o in zip( ("first", None, "third"), self.even, self.odd):
store.add_chunk(len(e), identifier=n, even=e, odd=o, sum=sum(e + o))

self.assertTrue(np.array_equal(store.get_array("even", 0), self.even[0]),
self.assertEqual(store.get_array("even", 0), np.array(self.even[0]),
"get_array returns wrong array for numeric index!")

self.assertTrue(np.array_equal(store.get_array("even", "first"), self.even[0]),
self.assertEqual(store.get_array("even", "first"), np.array(self.even[0]),
"get_array returns wrong array for string identifier!")

self.assertTrue(np.array_equal(store.get_array("even", "1"), self.even[1]),
self.assertEqual(store.get_array("even", "1"), np.array(self.even[1]),
"get_array returns wrong array for automatic identifier!")

self.assertTrue(np.array_equal(store.get_array("sum", 0), sum(self.even[0] + self.odd[0])),
self.assertEqual(store.get_array("sum", 0), sum(self.even[0] + self.odd[0]),
"get_array returns wrong array for numeric index!")

self.assertTrue(np.array_equal(store.get_array("sum", "first"), sum(self.even[0] + self.odd[0])),
self.assertEqual(store.get_array("sum", "first"), sum(self.even[0] + self.odd[0]),
"get_array returns wrong array for string identifier!")

self.assertTrue(np.array_equal(store.get_array("sum", "1"), sum(self.even[1] + self.odd[1])),
self.assertEqual(store.get_array("sum", "1"), sum(self.even[1] + self.odd[1]),
"get_array returns wrong array for automatic identifier!")

with self.assertRaises(KeyError, msg="Non-existing identifier!"):
Expand All @@ -199,10 +199,10 @@ def test_get_array_full(self):

store = FlattenedStorage(elem=[ [1], [2, 3], [4, 5, 6] ], chunk=[-1, -2, -3])
elem = store.get_array("elem")
self.assertTrue(np.array_equal(elem, [1, 2, 3, 4, 5, 6]),
self.assertEqual(elem, np.array([1, 2, 3, 4, 5, 6]),
f"get_array return did not return correct flat array, but {elem}.")
chunk = store.get_array("chunk")
self.assertTrue(np.array_equal(chunk, [-1, -2, -3]),
self.assertEqual(chunk, np.array([-1, -2, -3]),
f"get_array return did not return correct flat array, but {chunk}.")

def test_get_array_filled(self):
Expand All @@ -224,16 +224,17 @@ def test_get_array_filled(self):
])
val = store.get_array_filled("elem")
self.assertEqual(val.shape, (3, 3), "shape not correct!")
self.assertTrue(np.array_equal(val, [[1, -1, -1], [2, 3, -1], [4, 5, 6]]),
self.assertEqual(val, np.array([[1, -1, -1], [2, 3, -1], [4, 5, 6]]),
"values in returned array not the same as in original array!")
self.assertEqual(store.get_array_filled("fill")[0, 1], 23.42,
"incorrect fill value!")
val = store.get_array_filled("complex")
self.assertEqual(val.shape, (3, 3, 3), "shape not correct!")
self.assertTrue(np.array_equal(
store.get_array("chunk"),
store.get_array_filled("chunk"),
), "get_array_filled does not give same result as get_array for per chunk array")
self.assertEqual(
store.get_array("chunk"),
store.get_array_filled("chunk"),
"get_array_filled does not give same result as get_array for per chunk array"
)

def test_get_array_ragged(self):
"""get_array_ragged should return a raggend array of all elements in the storage."""
Expand All @@ -244,12 +245,12 @@ def test_get_array_ragged(self):
for i, v in enumerate(val):
self.assertEqual(len(v), store._per_chunk_arrays["length"][i],
f"array {i} has incorrect length!")
self.assertTrue(np.array_equal(v, [[1], [2, 3], [4, 5, 6]][i]),
self.assertEqual(v, np.array([[1], [2, 3], [4, 5, 6]][i]),
f"array {i} has incorrect values, {v}!")
self.assertTrue(np.array_equal(
self.assertEqual(
store.get_array("chunk"),
store.get_array_ragged("chunk"),
), "get_array_ragged does not give same result as get_array for per chunk array")
), "get_array_ragged does not give same result as get_array for per chunk array"

def test_get_array_ragged_dtype_stability(self):
"""get_array_ragged should (only!) convert top-most dimension to dtype=object and be of shape (n,) """
Expand Down Expand Up @@ -402,22 +403,22 @@ def test_getitem_setitem(self):
"""Using __getitem__/__setitem__ should be equivalent to using get_array/set_array."""
store = FlattenedStorage(even=self.even, odd=self.odd, mylen=[1, 2, 3])
for i in range(len(store)):
self.assertTrue(np.array_equal(
self.assertEqual(
store["even", i], store.get_array("even", i),
), f"getitem returned different value ({store['even', i]}) than get_array ({store.get_array('even', i)}) for chunk {i}"
f"getitem returned different value ({store['even', i]}) than get_array ({store.get_array('even', i)}) for chunk {i}"
)
self.assertEqual(store["mylen", i], store.get_array("mylen", i),
f"getitem returned different value ({store['mylen', i]}) than get_array ({store.get_array('mylen', i)}) for chunk {i}")
self.assertTrue(np.array_equal(store["even"], store.get_array("even")),
self.assertEqual(store["even"], store.get_array("even"),
f"getitem returned different value ({store['even']}) than get_array ({store.get_array('even')})")
self.assertTrue(np.array_equal(store["mylen"], store.get_array("mylen")),
self.assertEqual(store["mylen"], store.get_array("mylen"),
f"getitem returned different value ({store['mylen']}) than get_array ({store.get_array('mylen')})")
store["even", 0] = [4]
store["even", 1] = [2, 0]
store["mylen", 0] = 4
self.assertEqual(store.get_array("mylen", 0), 4, "setitem did not set item correctly.")
self.assertTrue(np.array_equal(store.get_array("even", 0), [4]), "setitem did not set item correctly.")
self.assertTrue(np.array_equal(store.get_array("even", 1), [2, 0]), "setitem did not set item correctly.")
self.assertEqual(store.get_array("even", 0), [4], "setitem did not set item correctly.")
self.assertEqual(store.get_array("even", 1), np.array([2, 0]), "setitem did not set item correctly.")

with self.assertRaises(IndexError, msg="Calling setitem with out index doesn't raise Error!"):
store["mylen"] = [1,2,3]
Expand All @@ -443,7 +444,7 @@ def test_hdf_chunklength_one(self):
for i in range(5):
store_foo = store.get_array("foo", i)
read_foo = read.get_array("foo", i)
self.assertTrue(np.array_equal(store_foo, read_foo),
self.assertEqual(store_foo, read_foo,
f"per element values not equal after reading from HDF! {store_foo} != {read_foo}")
self.assertEqual(store.get_array("bar", i), read.get_array("bar", i),
"per chunk values not equal after reading from HDF!")
Expand Down
1 change: 1 addition & 0 deletions tests/generic/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def food(bar=None, baz=None):
class TestImportAlarm(PyironTestCase):

def setUp(self):
super().setUp()
self.import_alarm = ImportAlarm()

@self.import_alarm
Expand Down
2 changes: 2 additions & 0 deletions tests/project/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class TestProjectData(PyironTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.file_location = dirname(abspath(__file__)).replace("\\", "/")
cls.project_name = join(cls.file_location, "test_data")

Expand All @@ -27,6 +28,7 @@ def tearDown(self):
self.project.remove(enable=True)

def setUp(self):
super().setUp()
self.project = Project(self.project_name)
self.data = ProjectData(project=self.project, table_name="data")
self.data.foo = "foo"
Expand Down
2 changes: 2 additions & 0 deletions tests/project/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def tearDownClass(cls):
pass

def setUp(self):
super().setUp()
self.project = Project(self.project_name)

def tearDown(self):
Expand Down Expand Up @@ -183,6 +184,7 @@ def test_symlink(self):

class TestToolRegistration(TestWithProject):
def setUp(self) -> None:
super().setUp()
self.tools = BaseTools(self.project)

def test_registration(self):
Expand Down
1 change: 1 addition & 0 deletions tests/server/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def setUpClass(cls):
queue_adapters.construct_adapters()

def setUp(self) -> None:
super().setUp()
self.server = Server()
self.server_main = Server(queue="main")

Expand Down
1 change: 1 addition & 0 deletions tests/state/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def tearDownClass(cls) -> None:
s.update()

def setUp(self) -> None:
super().setUp()
try:
self.default_loc.unlink()
except FileNotFoundError:
Expand Down
3 changes: 0 additions & 3 deletions tests/storage/test_has_stored_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,6 @@ def _drinks_default(self):


class TestInput(TestWithProject):
@classmethod
def setUpClass(cls):
super().setUpClass()

def setUp(self) -> None:
super().setUp()
Expand Down
1 change: 1 addition & 0 deletions tests/table/test_datamining.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def setUpClass(cls):
j.run()

def setUp(self):
super().setUp()
self.table = self.project.create.table('test_table')
self.table.filter_function = lambda j: j.name in ["test_a", "test_b"]
self.table.add['name'] = lambda j: j.name
Expand Down

0 comments on commit b6bf320

Please sign in to comment.