From e68beb1bd80a094bb248fa7feea4c0034c73c618 Mon Sep 17 00:00:00 2001 From: pariterre Date: Tue, 2 Jun 2020 15:59:39 -0400 Subject: [PATCH 1/6] Fixed rt declaration error when checking instead of overriding the last row --- pyomeca/processing/rototrans.py | 8 ++++++-- pyomeca/rototrans.py | 12 ++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pyomeca/processing/rototrans.py b/pyomeca/processing/rototrans.py index db6061a..4a1fa8a 100644 --- a/pyomeca/processing/rototrans.py +++ b/pyomeca/processing/rototrans.py @@ -154,7 +154,9 @@ def rototrans_from_markers( else: raise ValueError("`axis_to_recalculate must be `x`, `y` or `z`") - rt = caller(np.zeros((4, 4, origin.time.size))) + rt = np.zeros((4, 4, origin.time.size)) + rt[3, 3, :] = 1 + rt = caller(rt) rt[:3, 0, :] = x / np.linalg.norm(x, axis=0) rt[:3, 1, :] = y / np.linalg.norm(y, axis=0) rt[:3, 2, :] = z / np.linalg.norm(z, axis=0) @@ -165,7 +167,9 @@ def rototrans_from_markers( def rototrans_from_transposed_rototrans( caller: Callable, rt: xr.DataArray ) -> xr.DataArray: - rt_t = caller(np.zeros((4, 4, rt.time.size))) + rt_t = np.zeros((4, 4, rt.time.size)) + rt_t[3, 3, :] = 1 + rt_t = caller(rt_t) # the rotation part is just the transposed of the rotation rt_t.meca.rotation = rt.meca.rotation.transpose("col", "row", "time") diff --git a/pyomeca/rototrans.py b/pyomeca/rototrans.py index abb1dcc..2c47ebd 100644 --- a/pyomeca/rototrans.py +++ b/pyomeca/rototrans.py @@ -71,8 +71,16 @@ def __new__( coords["time"] = time # Make sure last line reads [0, 0, 0, 1] - data[3, :3, :] = 0 - data[3, 3, :] = 1 + zeros = data[3, :3, :] + ones = data[3, 3, :] + if not np.alltrue(zeros == 0) or not np.alltrue(ones == 1): + some_zeros = np.random.choice(zeros.ravel(), 5) + some_ones = np.random.choice(ones.ravel(), 5) + raise ValueError( + "Last line does not read [0, 0, 0, 1].\n" + f"Here are some values that should be 0: {some_zeros}\n" + f"And others that should 1: {some_ones}" + ) return xr.DataArray( data=data, From 4babcf788524323eb515d79306dcd313671120b0 Mon Sep 17 00:00:00 2001 From: pariterre Date: Tue, 2 Jun 2020 16:38:20 -0400 Subject: [PATCH 2/6] Fixed the test according to the new API for Rototrans --- tests/data/is_expected_array_val.csv | 2 +- tests/test_object_creation.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/data/is_expected_array_val.csv b/tests/data/is_expected_array_val.csv index 10efe63..83da086 100644 --- a/tests/data/is_expected_array_val.csv +++ b/tests/data/is_expected_array_val.csv @@ -65,4 +65,4 @@ idx,shape_val,first_last_val,mean_val,median_val,sum_val,nans_val 64,"(4, 51, 93)","(733.7041, 1.0)",319.803440918195,303.8114,6067310.8811,0 65,"(4, 51, 580)","(44.16278839111328, 1.0)",362.29798490932,337.751922607422,42535594.9182787,915 66,"(1, 38, 11600)","(-0.022051572799682617, 0.0)",-0.001679033595597,0,-740.118008939112,0 -67,"(4, 4, 580)","(744.5535888671875, 1.0)",385.4765160798091,278.7399444580078,3535590.605484009,108 +67,"(4, 4, 580)","(0.306872, 1.0)",107.58354,0.0,975137.207354,216 diff --git a/tests/test_object_creation.py b/tests/test_object_creation.py index 9dbf673..add21de 100644 --- a/tests/test_object_creation.py +++ b/tests/test_object_creation.py @@ -68,7 +68,14 @@ def test_rototrans_creation(): np.testing.assert_array_equal(x=array, y=xr.DataArray(np.eye(4)[..., np.newaxis])) assert array.dims == dims - array = Rototrans(MARKERS_DATA.values, time=MARKERS_DATA.time) + data = Markers(MARKERS_DATA.values) + array = Rototrans.from_markers( + origin=data[:, [0], :], + axis_1=data[:, [0, 1], :], + axis_2=data[:, [0, 2], :], + axes_name="xy", + axis_to_recalculate="y", + ) is_expected_array(array, **EXPECTED_VALUES[67]) size = 4, 4, 100 From 96b4fcfcb9e340ee4719171ac42c3446f1306f28 Mon Sep 17 00:00:00 2001 From: pariterre Date: Tue, 2 Jun 2020 16:45:24 -0400 Subject: [PATCH 3/6] Added a negative test for value error --- tests/test_processing_rt.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_processing_rt.py b/tests/test_processing_rt.py index f5bd045..d23875c 100644 --- a/tests/test_processing_rt.py +++ b/tests/test_processing_rt.py @@ -66,6 +66,12 @@ def test_construct_rt(): with pytest.raises(IndexError): Rototrans(data=np.zeros(1)) + with pytest.raises(ValueError): + Rototrans(data=np.zeros((4, 4, 1))) + + with pytest.raises(ValueError): + Rototrans(data=np.ones((4, 4, 1))) + with pytest.raises(IndexError): Rototrans.from_euler_angles( angles=random_vector[..., :5], From f62327d1037ebf659184b2deae3d7ff191054df7 Mon Sep 17 00:00:00 2001 From: pariterre Date: Tue, 2 Jun 2020 17:30:12 -0400 Subject: [PATCH 4/6] blacked.. --- tests/test_processing_rt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_processing_rt.py b/tests/test_processing_rt.py index d23875c..44e9f3f 100644 --- a/tests/test_processing_rt.py +++ b/tests/test_processing_rt.py @@ -68,7 +68,7 @@ def test_construct_rt(): with pytest.raises(ValueError): Rototrans(data=np.zeros((4, 4, 1))) - + with pytest.raises(ValueError): Rototrans(data=np.ones((4, 4, 1))) From a0d2acc0caab90a59f0346670f50444a69f4922b Mon Sep 17 00:00:00 2001 From: pariterre Date: Wed, 3 Jun 2020 12:39:02 -0400 Subject: [PATCH 5/6] Answered comment --- tests/test_object_creation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_object_creation.py b/tests/test_object_creation.py index add21de..b5c7db1 100644 --- a/tests/test_object_creation.py +++ b/tests/test_object_creation.py @@ -70,9 +70,9 @@ def test_rototrans_creation(): data = Markers(MARKERS_DATA.values) array = Rototrans.from_markers( - origin=data[:, [0], :], - axis_1=data[:, [0, 1], :], - axis_2=data[:, [0, 2], :], + origin=data.isel(channel=[0]), + axis_1=data.isel(channel=[0, 1]), + axis_2=data.isel(channel=[0, 2]), axes_name="xy", axis_to_recalculate="y", ) From 6d6dcfabd1489dc178d5101be65a45e1880b2d8d Mon Sep 17 00:00:00 2001 From: pariterre Date: Wed, 3 Jun 2020 12:40:19 -0400 Subject: [PATCH 6/6] Fixed what seems to be a copy-paste mistake in signature, but not sure, please confirm --- pyomeca/rototrans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyomeca/rototrans.py b/pyomeca/rototrans.py index 2c47ebd..1579b97 100644 --- a/pyomeca/rototrans.py +++ b/pyomeca/rototrans.py @@ -92,7 +92,7 @@ def __new__( @classmethod def from_random_data( - cls, distribution: str = "normal", size: tuple = (4, 4, 100), **kwargs + cls, distribution: str = "normal", size: tuple = (3, 1, 100), **kwargs ) -> xr.DataArray: """ Create random data from a specified distribution (normal by default) using random walk.