Skip to content

Commit

Permalink
better test structure; fix segmentation fault in CI due to two names …
Browse files Browse the repository at this point in the history
…of one method which caused typeguard trouble
  • Loading branch information
s-m-e committed Dec 29, 2022
1 parent 3d13244 commit 958b87b
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions tests/test_customtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,14 @@


@pytest.mark.parametrize("arch,conv,ctypes,dll_handle", get_context(__file__))
def test_customtype(arch, conv, ctypes, dll_handle):
@pytest.mark.parametrize("data", [
[1, 2, 3, 4],
[1.0, 2.0, 3.0, 4.0],
(1.0, 2.0, 3.0, 4.0),
np.array([1.0, 2.0, 3.0, 4.0], dtype = 'f8'),
array('d', [1.0, 2.0, 3.0, 4.0]),
])
def test_customtype(data, arch, conv, ctypes, dll_handle):
"""
Test basic handling of custom ctypes data types
"""
Expand All @@ -87,6 +94,7 @@ def from_param(self, param: Any) -> Any:
Called by ctypes/zugbruecke, dispatches to different implementations
"""

param = list(param) if isinstance(param, tuple) else param
typename = type(param).__name__

if hasattr(self, "from_" + typename):
Expand All @@ -100,7 +108,7 @@ def from_array(self, param: array) -> Any:
"""

if param.typecode != "d":
raise TypeError("must be an array of doubles")
raise TypeError("must be an array of type double")
ptr, _ = param.buffer_info()
return ctypes.cast(
ptr,
Expand All @@ -117,13 +125,13 @@ def from_list(self, param: Union[List[float], Tuple[float, ...]]) -> Any:
ctypes.POINTER(ctypes.c_double),
)

from_tuple = from_list

def from_ndarray(self, param: np.ndarray) -> Any:
"""
Implementation for numpy.ndarray
"""

if param.dtype != np.float64:
raise TypeError("must be an ndarray of dtype doubles")
return param.ctypes.data_as(
ctypes.POINTER(ctypes.c_double)
)
Expand All @@ -141,8 +149,4 @@ def from_ndarray(self, param: np.ndarray) -> Any:
avg_dll.argtypes = (DoubleArray, ctypes.c_int)
avg_dll.restype = ctypes.c_double

assert pytest.approx(2.5, 0.0000001) == avg_dll([1, 2, 3, 4], 4)
assert pytest.approx(2.5, 0.0000001) == avg_dll([1.0, 2.0, 3.0, 4.0], 4)
assert pytest.approx(2.5, 0.0000001) == avg_dll((1.0, 2.0, 3.0, 4.0), 4)
assert pytest.approx(2.5, 0.0000001) == avg_dll(np.array([1.0, 2.0, 3.0, 4.0], dtype = 'f8'), 4)
assert pytest.approx(2.5, 0.0000001) == avg_dll(array('d', [1.0, 2.0, 3.0, 4.0]), 4)
assert pytest.approx(2.5, 0.0000001) == avg_dll(data, 4)

0 comments on commit 958b87b

Please sign in to comment.