From 8e04e37c0a9fef53d71f9b1ab576422d2520b08f Mon Sep 17 00:00:00 2001 From: Ziga Luksic Date: Wed, 9 Nov 2022 13:04:53 +0100 Subject: [PATCH] add axis parameter to MergeFeatureTask --- core/eolearn/core/core_tasks.py | 8 ++++---- core/eolearn/tests/test_core_tasks.py | 9 +++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/core/eolearn/core/core_tasks.py b/core/eolearn/core/core_tasks.py index 6df8bb974..8b56f802d 100644 --- a/core/eolearn/core/core_tasks.py +++ b/core/eolearn/core/core_tasks.py @@ -508,11 +508,11 @@ def zip_method(self, *f): class MergeFeatureTask(ZipFeatureTask): - """Merges multiple features together by concatenating their data along the last axis.""" + """Merges multiple features together by concatenating their data along the specified axis.""" - def zip_method(self, *f: np.ndarray, dtype: Union[None, np.dtype, type] = None) -> np.ndarray: - """Concatenates the data of features along the last axis.""" - return np.concatenate(f, axis=-1, dtype=dtype) # pylint: disable=unexpected-keyword-arg + def zip_method(self, *f: np.ndarray, dtype: Union[None, np.dtype, type] = None, axis: int = -1) -> np.ndarray: + """Concatenates the data of features along the specified axis.""" + return np.concatenate(f, axis=axis, dtype=dtype) # pylint: disable=unexpected-keyword-arg class ExtractBandsTask(MapFeatureTask): diff --git a/core/eolearn/tests/test_core_tasks.py b/core/eolearn/tests/test_core_tasks.py index f8a9dec59..9536180e1 100644 --- a/core/eolearn/tests/test_core_tasks.py +++ b/core/eolearn/tests/test_core_tasks.py @@ -308,7 +308,8 @@ def test_move_feature(): assert "MTless2" in patch_dst[FeatureType.MASK_TIMELESS] -def test_merge_features(): +@pytest.mark.parametrize("axis", (0, -1)) +def test_merge_features(axis): patch = EOPatch() shape = (10, 5, 5, 3) @@ -332,10 +333,10 @@ def test_merge_features(): for feat, dat in zip(features, data): patch = AddFeatureTask(feat)(patch, dat) - patch = MergeFeatureTask(features[:3], (FeatureType.MASK, "merged"))(patch) - patch = MergeFeatureTask(features[3:], (FeatureType.MASK_TIMELESS, "merged_timeless"))(patch) + patch = MergeFeatureTask(features[:3], (FeatureType.MASK, "merged"), axis=axis)(patch) + patch = MergeFeatureTask(features[3:], (FeatureType.MASK_TIMELESS, "merged_timeless"), axis=axis)(patch) - expected = np.concatenate([patch[f] for f in features[:3]], axis=-1) + expected = np.concatenate([patch[f] for f in features[:3]], axis=axis) assert np.array_equal(patch.mask["merged"], expected)