From c9ddef72c8e9903c5e6a6a93c072550f73005617 Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Tue, 9 Jun 2020 22:10:36 -0700 Subject: [PATCH] Validate inputs for cropping (#97) * check crop_size arguments * pep8 * check for floats * lists or tuples --- caliban_toolbox/reshape_data.py | 37 +++++++++++++++--- caliban_toolbox/reshape_data_test.py | 56 ++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 5 deletions(-) diff --git a/caliban_toolbox/reshape_data.py b/caliban_toolbox/reshape_data.py index 7a1eba5..77cb33b 100644 --- a/caliban_toolbox/reshape_data.py +++ b/caliban_toolbox/reshape_data.py @@ -54,12 +54,39 @@ def crop_multichannel_data(X_data, y_data, crop_size=None, crop_num=None, overla """ # sanitize inputs - if len(crop_size) != 2: - raise ValueError('crop_size must be a tuple of (row_crop, col_crop), ' - 'got {}'.format(crop_size)) + if crop_size is None and crop_num is None: + raise ValueError('Either crop_size or crop_num must be specified') - if not crop_size[0] > 0 and crop_size[1] > 0: - raise ValueError('crop_size entries must be positive numbers') + if crop_size is not None and crop_num is not None: + raise ValueError('Only one of crop_size and crop_num should be provided') + + if crop_size is not None: + if not isinstance(crop_size, (tuple, list)): + raise ValueError('crop_size must be a tuple or list') + + if len(crop_size) != 2: + raise ValueError('crop_size must be a tuple of (row_crop, col_crop), ' + 'got {}'.format(crop_size)) + + if not crop_size[0] > 0 and crop_size[1] > 0: + raise ValueError('crop_size entries must be positive') + + if not isinstance(crop_size[0], int) and isinstance(crop_size[1], int): + raise ValueError('crop_size entries must be integers') + + if crop_num is not None: + if not isinstance(crop_num, (tuple, list)): + raise ValueError('crop_num must be a tuple or list') + + if len(crop_num) != 2: + raise ValueError('crop_num must be a tuple of (num_row, num_col), ' + 'got {}'.format(crop_size)) + + if not crop_num[0] > 0 and crop_num[1] > 0: + raise ValueError('crop_num entries must be positive') + + if not isinstance(crop_num[0], int) and isinstance(crop_num[1], int): + raise ValueError('crop_num entries must be integers') if overlap_frac < 0 or overlap_frac > 1: raise ValueError('overlap_frac must be between 0 and 1') diff --git a/caliban_toolbox/reshape_data_test.py b/caliban_toolbox/reshape_data_test.py index 0215cb2..9149abe 100644 --- a/caliban_toolbox/reshape_data_test.py +++ b/caliban_toolbox/reshape_data_test.py @@ -25,6 +25,7 @@ # ============================================================================== import os import tempfile +import pytest import numpy as np import xarray as xr @@ -65,6 +66,61 @@ def test_crop_multichannel_data(): assert log_data["num_crops"] == expected_crop_num + # invalid arguments + + # no crop_size or crop_num + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data) + + # both crop_size and crop_num + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data, + crop_size=(20, 20), crop_num=(20, 20)) + # bad crop_size dtype + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data, + crop_size=5) + # bad crop_size shape + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data, + crop_size=(10, 5, 2)) + # bad crop_size values + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data, + crop_size=(0, 5)) + # bad crop_size values + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data, + crop_size=(1.5, 5)) + # bad crop_num dtype + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data, + crop_num=5) + # bad crop_num shape + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data, + crop_num=(10, 5, 2)) + # bad crop_num values + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data, + crop_num=(0, 5)) + # bad crop_num values + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data, + crop_num=(1.5, 5)) + # bad overlap_frac value + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data, + overlap_frac=1.2) + # bad X_data dims + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data[0], y_data=test_y_data, + crop_size=(5, 5)) + # bad y_data dims + with pytest.raises(ValueError): + _ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data[0], + crop_num=(5, 5)) + def test_create_slice_data(): # test output shape with even division of slice