Skip to content

Commit

Permalink
Validate inputs for cropping (#97)
Browse files Browse the repository at this point in the history
* check crop_size arguments

* pep8

* check for floats

* lists or tuples
  • Loading branch information
ngreenwald committed Jun 10, 2020
1 parent 23d9e2a commit c9ddef7
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 5 deletions.
37 changes: 32 additions & 5 deletions caliban_toolbox/reshape_data.py
Expand Up @@ -54,12 +54,39 @@ def crop_multichannel_data(X_data, y_data, crop_size=None, crop_num=None, overla
"""

# sanitize inputs
if len(crop_size) != 2:
raise ValueError('crop_size must be a tuple of (row_crop, col_crop), '
'got {}'.format(crop_size))
if crop_size is None and crop_num is None:
raise ValueError('Either crop_size or crop_num must be specified')

if not crop_size[0] > 0 and crop_size[1] > 0:
raise ValueError('crop_size entries must be positive numbers')
if crop_size is not None and crop_num is not None:
raise ValueError('Only one of crop_size and crop_num should be provided')

if crop_size is not None:
if not isinstance(crop_size, (tuple, list)):
raise ValueError('crop_size must be a tuple or list')

if len(crop_size) != 2:
raise ValueError('crop_size must be a tuple of (row_crop, col_crop), '
'got {}'.format(crop_size))

if not crop_size[0] > 0 and crop_size[1] > 0:
raise ValueError('crop_size entries must be positive')

if not isinstance(crop_size[0], int) and isinstance(crop_size[1], int):
raise ValueError('crop_size entries must be integers')

if crop_num is not None:
if not isinstance(crop_num, (tuple, list)):
raise ValueError('crop_num must be a tuple or list')

if len(crop_num) != 2:
raise ValueError('crop_num must be a tuple of (num_row, num_col), '
'got {}'.format(crop_size))

if not crop_num[0] > 0 and crop_num[1] > 0:
raise ValueError('crop_num entries must be positive')

if not isinstance(crop_num[0], int) and isinstance(crop_num[1], int):
raise ValueError('crop_num entries must be integers')

if overlap_frac < 0 or overlap_frac > 1:
raise ValueError('overlap_frac must be between 0 and 1')
Expand Down
56 changes: 56 additions & 0 deletions caliban_toolbox/reshape_data_test.py
Expand Up @@ -25,6 +25,7 @@
# ==============================================================================
import os
import tempfile
import pytest

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -65,6 +66,61 @@ def test_crop_multichannel_data():

assert log_data["num_crops"] == expected_crop_num

# invalid arguments

# no crop_size or crop_num
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data)

# both crop_size and crop_num
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
crop_size=(20, 20), crop_num=(20, 20))
# bad crop_size dtype
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
crop_size=5)
# bad crop_size shape
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
crop_size=(10, 5, 2))
# bad crop_size values
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
crop_size=(0, 5))
# bad crop_size values
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
crop_size=(1.5, 5))
# bad crop_num dtype
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
crop_num=5)
# bad crop_num shape
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
crop_num=(10, 5, 2))
# bad crop_num values
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
crop_num=(0, 5))
# bad crop_num values
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
crop_num=(1.5, 5))
# bad overlap_frac value
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
overlap_frac=1.2)
# bad X_data dims
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data[0], y_data=test_y_data,
crop_size=(5, 5))
# bad y_data dims
with pytest.raises(ValueError):
_ = reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data[0],
crop_num=(5, 5))


def test_create_slice_data():
# test output shape with even division of slice
Expand Down

0 comments on commit c9ddef7

Please sign in to comment.