Skip to content

Commit

Permalink
get_data_window cleanup (#2510) (#2559)
Browse files Browse the repository at this point in the history
* Cleaner get_data_window implementation.

* Add tests.

* Separate assertions into distinct tests.

Co-authored-by: Ryan Grout <groutr@users.noreply.github.com>
  • Loading branch information
snowman2 and groutr committed Aug 19, 2022
1 parent ee75596 commit 570b169
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 30 deletions.
55 changes: 29 additions & 26 deletions rasterio/windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,36 +154,39 @@ def get_data_window(arr, nodata=None):
-------
Window
"""

num_dims = len(arr.shape)
if num_dims > 3:
if not 0 < arr.ndim <=3 :
raise WindowError(
"get_data_window input array must have no more than "
"3 dimensions")

if nodata is None:
if not hasattr(arr, 'mask'):
return Window.from_slices((0, arr.shape[-2]), (0, arr.shape[-1]))
"get_data_window input array must have 1, 2, or 3 dimensions")

# If nodata is defined, construct mask from that value
# Otherwise retrieve mask from array (if it is masked)
# Finally try returning a full window (nodata=None and nothing in arr is masked)
if nodata is not None:
arr_mask = arr != nodata
elif np.ma.is_masked(arr):
arr_mask = ~np.ma.getmask(arr)
else:
arr = np.ma.masked_array(arr, arr == nodata)

if num_dims == 2:
data_rows, data_cols = np.where(np.equal(arr.mask, False))
else:
data_rows, data_cols = np.where(
np.any(np.equal(np.rollaxis(arr.mask, 0, 3), False), axis=2))

if data_rows.size:
row_range = (data_rows.min(), data_rows.max() + 1)
else:
row_range = (0, 0)
if arr.ndim == 1:
full_window = ((0, arr.size), (0, 0))
else:
full_window = ((0, arr.shape[-2]), (0, arr.shape[-1]))
return Window.from_slices(*full_window)

if data_cols.size:
col_range = (data_cols.min(), data_cols.max() + 1)
else:
col_range = (0, 0)
if arr.ndim == 3:
arr_mask = np.any(arr_mask, axis=0)

return Window.from_slices(row_range, col_range)
# We only have 1 or 2 dimension cases to process
v = []
for nz in arr_mask.nonzero():
if nz.size:
v.append((nz.min(), nz.max() + 1))
else:
v.append((0, 0))

if arr_mask.ndim == 1:
v.append((0, 0))

return Window.from_slices(*v)


def _compute_union(w1, w2):
Expand Down
25 changes: 21 additions & 4 deletions tests/test_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,20 +469,37 @@ def test_read_with_window_class(path_rgb_byte_tif):
assert subset.shape == (10, 10)


def test_data_window_invalid_arr_dims():
def test_data_window_invalid_4d():
"""An array of more than 3 dimensions is invalid."""
arr = np.ones((3, 3, 3, 3))
# Test > 3 dims
with pytest.raises(WindowError):
get_data_window(arr)
get_data_window(np.ones((3, 3, 3, 3)))


def test_data_window_full():
def test_data_window_invalid_0d():
"""An array of less than 1 dimension is invalid."""
# Test < 1 dim
with pytest.raises(WindowError):
get_data_window(np.ones(()))


def test_data_window_full_2d():
"""Get window of entirely valid data array."""
arr = np.ones((3, 3))
window = get_data_window(arr)
assert window == Window.from_slices((0, 3), (0, 3))


def test_data_window_full_1d():
window = get_data_window(np.ones(3))
assert window == Window.from_slices((0, 3), (0, 0))


def test_data_window_full_3d():
window = get_data_window(np.ones((3, 3, 3)))
assert window == Window.from_slices((0, 3), (0, 3))


def test_data_window_nodata():
"""Get window of arr with nodata."""
arr = np.ones((3, 3))
Expand Down

0 comments on commit 570b169

Please sign in to comment.