Skip to content

Commit

Permalink
Merge 3c2aa2f into 5880577
Browse files Browse the repository at this point in the history
  • Loading branch information
ngreenwald committed Apr 26, 2020
2 parents 5880577 + 3c2aa2f commit 7c3f1aa
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
20 changes: 17 additions & 3 deletions deepcell_toolbox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def untile_image(tiles, tiles_info,
return image


def resize(data, shape, data_format='channels_last'):
def resize(data, shape, data_format='channels_last', data_type='X'):
"""Resize the data to the given shape.
Uses openCV to resize the data if the data is a single channel, as it
is very fast. However, openCV does not support multi-channel resizing,
Expand All @@ -279,6 +279,8 @@ def resize(data, shape, data_format='channels_last'):
Batch and channel dimensions are handled automatically and preserved.
data_format (str): determines the order of the channel axis,
one of 'channels_first' and 'channels_last'.
data_type (str): determines whether resizing raw data or labels
one of 'X' or 'y'
Raises:
ValueError: ndim of data not 3 or 4
Expand Down Expand Up @@ -309,12 +311,24 @@ def resize(data, shape, data_format='channels_last'):
else:
shape = tuple(list(shape) + [data.shape[channel_axis]])

_resize = lambda d: transform.resize(d, shape, mode='constant', preserve_range=True)
# linear interpolation (order 1) for image data, near neighbor (order 0) for labels
order = 1 if data_type == 'X' else 0

# anti_aliasing introduces spurious labels, include only for image data
anti_aliasing = data_type == 'X'

_resize = lambda d: transform.resize(d, shape, mode='constant', preserve_range=True,
order=order, anti_aliasing=anti_aliasing)
# single channel image, resize with cv2
else:
shape = tuple(shape)

_resize = lambda d: np.expand_dims(cv2.resize(np.squeeze(d), shape), axis=channel_axis)
# linear interpolation for image data, nearest neighbor for labels
interpolation = cv2.INTER_LINEAR if data_type == 'X' else cv2.INTER_NEAREST

_resize = lambda d: np.expand_dims(cv2.resize(np.squeeze(d), shape,
interpolation=interpolation),
axis=channel_axis)

# Check for batch dimension to loop over
if len(data.shape) == 4:
Expand Down
21 changes: 20 additions & 1 deletion deepcell_toolbox/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def test_resize():

for out in out_shapes:
for c in channel_sizes:

# batch, channel first
in_shape = [c] + base_shape + [4]
out_shape = tuple([c] + out + [4])
Expand All @@ -193,6 +192,26 @@ def test_resize():
rs = utils.resize(np.random.rand(*in_shape), out, data_format='channels_last')
assert out_shape == rs.shape

# make sure label data is not linearly interpolated and returns only the same ints

# no batch, channel last
in_shape = base_shape + [c]
out_shape = tuple(out + [c])
in_data = np.random.choice(a=[0, 1, 9, 20], size=in_shape, replace=True)
rs = utils.resize(in_data, out, data_format='channels_last', data_type='y')
assert out_shape == rs.shape
assert np.all(rs == np.floor(rs))
assert np.all(np.unique(rs) == [0, 1, 9, 20])

# batch, channel first
in_shape = [c] + base_shape + [4]
out_shape = tuple([c] + out + [4])
in_data = np.random.choice(a=[0, 1, 9, 20], size=in_shape, replace=True)
rs = utils.resize(in_data, out, data_format='channels_first', data_type='y')
assert out_shape == rs.shape
assert np.all(rs == np.floor(rs))
assert np.all(np.unique(rs) == [0, 1, 9, 20])

# Wrong data size
with pytest.raises(ValueError):
im = np.random.rand(20, 20)
Expand Down

0 comments on commit 7c3f1aa

Please sign in to comment.