Skip to content

Commit

Permalink
(tests): adds checks for accessors of mat3 type
Browse files Browse the repository at this point in the history
  • Loading branch information
wpumacay committed Jun 27, 2023
1 parent 63fd3c4 commit 7dd0887
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions tests/python/mat3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import math3d as m3d

from typing import Type, Union, List, cast
from typing import Type, Union, cast

Matrix3Cls = Type[Union[m3d.Matrix3f, m3d.Matrix3d]]
Vector3Cls = Type[Union[m3d.Vector3f, m3d.Vector3d]]
Expand All @@ -16,6 +16,10 @@ def mat3_all_close(mat: Matrix3, mat_np: np.ndarray, epsilon: float = 1e-5) -> b
return np.allclose(cast(np.ndarray, mat), mat_np, atol=epsilon)


def vec3_all_close(vec: Vector3, vec_np: np.ndarray, epsilon: float = 1e-5) -> bool:
return np.allclose(cast(np.ndarray, vec), vec_np, atol=epsilon)


@pytest.mark.parametrize(
"Mat3, Vec3, FloatType",
[
Expand Down Expand Up @@ -56,7 +60,36 @@ def test_numpy_array_constructor(
) -> None:
mat = Mat3(
np.array(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=FloatType
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
dtype=FloatType,
)
)
assert mat3_all_close(mat, np.arange(1, 10).reshape(3, 3).astype(FloatType))


@pytest.mark.parametrize(
"Mat3, Vec3, FloatType",
[
(m3d.Matrix3f, m3d.Vector3f, np.float32),
(m3d.Matrix3d, m3d.Vector3d, np.float64),
],
)
def test_mat3_accessors(Mat3, Vec3, FloatType) -> None:
mat = Mat3(
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=FloatType)
)

# __getitem__ by using a single entry should return the requested column
col0, col1, col2 = mat[0], mat[1], mat[2]
assert type(col0) == Vec3 and type(col1) == Vec3 and type(col2) == Vec3
assert vec3_all_close(col0, np.array([1.0, 4.0, 7.0], dtype=FloatType))
assert vec3_all_close(col1, np.array([2.0, 5.0, 8.0], dtype=FloatType))
assert vec3_all_close(col2, np.array([3.0, 6.0, 9.0], dtype=FloatType))

# __getitem__ by using a tuple
assert mat[0, 0] == 1.0 and mat[0, 1] == 2.0 and mat[0, 2] == 3.0
assert mat[1, 0] == 4.0 and mat[1, 1] == 5.0 and mat[1, 2] == 6.0
assert mat[2, 0] == 7.0 and mat[2, 1] == 8.0 and mat[2, 2] == 9.0

# __getitem__ by using a slice
# TODO(wilbert): implement __getitem__ to retrieve a view of the vector

0 comments on commit 7dd0887

Please sign in to comment.