Skip to content

Commit

Permalink
Modify workflow to use the new PanopticNet deep watershed models. (#88)
Browse files Browse the repository at this point in the history
* bump deepcell-toolbox to 0.2.0

* add deep_watershed to processing and update postprocess env vars.

* add default model sizes as environment variables

* Use tile/untile if image is too big, or pad/unpad if image is too small.

* Just warn if label detect is enabled AND the model is specified. use the specified model.

* add tests for tracking_consumer.get_tracker, base_consumer._put_back_hash, and a _conumser test with detections enabled.
  • Loading branch information
willgraf committed Mar 2, 2020
1 parent d499cb1 commit 412eef1
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 17 deletions.
16 changes: 12 additions & 4 deletions redis_consumer/consumers/base_consumer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,8 @@ def lrem(key, count, value):
redis_client.lrem = lrem
consumer = consumers.Consumer(redis_client, DummyStorage(), queue_name)
consumer.get_redis_hash = lambda: '%s:f.tiff:failed' % queue_name
print(redis_client.work_queue)
print(redis_client.processing_queue)

consumer.consume()
print(redis_client.work_queue)
print(redis_client.processing_queue)
assert _processed is True

_processed = False
Expand All @@ -320,6 +317,17 @@ def lrem(key, count, value):
consumer.consume()
assert _processed is True

_processed = 0

def G(*_):
global _processed
_processed += 1
return 'waiting'

consumer._consume = G
consumer.consume()
assert _processed == N

def test__consume(self):
with np.testing.assert_raises(NotImplementedError):
consumer = consumers.Consumer(None, None, 'q')
Expand Down
69 changes: 62 additions & 7 deletions redis_consumer/consumers/image_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

import numpy as np

from deepcell_toolbox.utils import tile_image, untile_image

from redis_consumer.consumers import TensorFlowServingConsumer
from redis_consumer import utils
from redis_consumer import settings
Expand Down Expand Up @@ -165,8 +167,8 @@ def _consume(self, redis_hash):
'identity_started': self.hostname,
})

cuts = hvals.get('cuts', '0')
field = hvals.get('field_size', '61')
cuts = hvals.get('cuts', '0') # TODO: deprecated
field = hvals.get('field_size', '61') # TODO: deprecated

# Overridden with LABEL_DETECT_ENABLED
model_name = hvals.get('model_name')
Expand Down Expand Up @@ -201,8 +203,13 @@ def _consume(self, redis_hash):
# Save shape value for postprocessing purposes
# TODO this is a big janky
self._rawshape = image.shape
label = None
if settings.LABEL_DETECT_ENABLED and model_name and model_version:
self.logger.warning('Label Detection is enabled, but the model'
' %s:%s was specified in Redis.',
model_name, model_version)

if settings.LABEL_DETECT_ENABLED:
elif settings.LABEL_DETECT_ENABLED:
# Detect image label type
label = hvals.get('label', '')
if not label:
Expand All @@ -222,16 +229,64 @@ def _consume(self, redis_hash):
# Send data to the model
self.update_key(redis_hash, {'status': 'predicting'})

if streaming:
image = self.process_big_image(
cuts, image, field, model_name, model_version)
model_shape = settings.MODEL_SIZES.get(
'{}:{}'.format(model_name, model_version), max(image.shape))

if (image.shape[image.ndim - 3] < model_shape or
image.shape[image.ndim - 2] < model_shape):
# tiling not necessary, but image must be padded.
pad_width = []
for i in range(image.ndim):
if i in {image.ndim - 3, image.ndim - 2}:
diff = model_shape - image.shape[i]
if diff % 2:
pad_width.append((diff // 2, diff // 2 + 1))
else:
pad_width.append((diff // 2, diff // 2))
else:
pad_width.append((0, 0))
padded_img = np.pad(image, pad_width, 'reflect')
image = self.grpc_image(padded_img, model_name, model_version)

for i, j in enumerate(image):

self.logger.critical('output %s shape is %s', i, j.shape)

# unpad results
pad_width.insert(0, (0, 0)) # batch size
if isinstance(image, list):
image = [utils.unpad_image(i, pad_width) for i in image]
else:
image = utils.unpad_image(image, pad_width)

elif (image.shape[image.ndim - 3] > model_shape or
image.shape[image.ndim - 2] > model_shape):
# need to tile!
tiles, tiles_info = tile_image(
np.expand_dims(image, axis=0),
model_input_shape=(model_shape, model_shape),
stride_ratio=0.75)

# max_batch_size is 1 by default.
# dependent on the tf-serving configuration
results = []
for t in range(tiles.shape[0]):
output = self.grpc_image(tiles[t], model_name, model_version)
if not results:
results = output
else:
for i, o in enumerate(output):
results[i] = np.vstack((results[i], o))

image = [untile_image(r, tiles_info) for r in results]

else:
image = self.grpc_image(image, model_name, model_version)

# Post-process model results
self.update_key(redis_hash, {'status': 'post-processing'})

if settings.LABEL_DETECT_ENABLED:
if settings.LABEL_DETECT_ENABLED and label is not None:
post_funcs = utils._pick_postprocess(label).split(',')
else:
post_funcs = hvals.get('postprocess_function', '').split(',')
Expand Down
24 changes: 24 additions & 0 deletions redis_consumer/consumers/image_consumer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,27 @@ def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613
consumer.grpc_image = grpc_image_list
result = consumer._consume(dummyhash)
assert result == consumer.final_status

settings.LABEL_DETECT_ENABLED = True
settings.SCALE_DETECT_ENABLED = True

# test with model_name and model_version
redis_client.hgetall = lambda x: {
'model_name': 'model',
'model_version': '0',
'label': '0',
'scale': '1',
'postprocess_function': '',
'preprocess_function': '',
'file_name': 'test_image.tiff',
'input_file_name': 'test_image.tiff',
'output_file_name': 'test_image.tiff'
}
redis_client.hmset = lambda x, y: True
consumer = consumers.ImageFileConsumer(redis_client, storage, prefix)
consumer._handle_error = _handle_error
consumer.detect_scale = detect_scale
consumer.detect_label = detect_label
consumer.grpc_image = grpc_image
result = consumer._consume(dummyhash)
assert result == consumer.final_status
18 changes: 18 additions & 0 deletions redis_consumer/consumers/tracking_consumer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import pytest

from redis_consumer import consumers
from redis_consumer import settings
from redis_consumer import utils


Expand Down Expand Up @@ -207,6 +208,23 @@ def test_is_valid_hash(self):
assert consumer.is_valid_hash('track:1234567890:file.trk') is True
assert consumer.is_valid_hash('track:1234567890:file.trks') is True

def test__get_tracker(self):
queue = 'track'
items = ['item%s' % x for x in range(1, 4)]

storage = DummyStorage()
redis_client = DummyRedis(items)
redis_client.hget = lambda *x: x[0]

shape = (5, 21, 21, 1)
raw = np.random.random(shape)
segmented = np.random.randint(1, 10, size=shape)

settings.NORMALIZE_TRACKING = True

consumer = consumers.TrackingConsumer(redis_client, storage, queue)
consumer._get_tracker('item1', {}, raw, segmented)

def test__consume(self):
queue = 'track'
items = ['item%s' % x for x in range(1, 4)]
Expand Down
9 changes: 8 additions & 1 deletion redis_consumer/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Functions for pre- and post-processing image data"""
"""DEPRECATED. Please use the "deepell_toolbox" package instead.
Functions for pre- and post-processing image data
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=W0611

from deepcell_toolbox import normalize
from deepcell_toolbox import mibi
from deepcell_toolbox import watershed
from deepcell_toolbox import pixelwise
from deepcell_toolbox import correct_drift

from deepcell_toolbox.deep_watershed import deep_watershed

from deepcell_toolbox import retinanet_semantic_to_label_image
from deepcell_toolbox import retinanet_to_label_image

Expand Down
19 changes: 15 additions & 4 deletions redis_consumer/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def _strip(x):
'mibi': processing.mibi,
'watershed': processing.watershed,
'retinanet': processing.retinanet_to_label_image,
'retinanet-semantic': processing.retinanet_semantic_to_label_image
'retinanet-semantic': processing.retinanet_semantic_to_label_image,
'deep_watershed': processing.deep_watershed,
},
}

Expand Down Expand Up @@ -170,9 +171,19 @@ def _strip(x):
2: CYTOPLASM_MODEL
}

PHASE_POSTPROCESS = config('PHASE_POSTPROCESS', default='retinanet-semantic', cast=str)
CYTOPLASM_POSTPROCESS = config('CYTOPLASM_POSTPROCESS', default='retinanet-semantic', cast=str)
NUCLEAR_POSTPROCESS = config('NUCLEAR_POSTPROCESS', default='retinanet', cast=str)
PHASE_POSTPROCESS = config('PHASE_POSTPROCESS', default='deep_watershed', cast=str)
CYTOPLASM_POSTPROCESS = config('CYTOPLASM_POSTPROCESS', default='deep_watershed', cast=str)
NUCLEAR_POSTPROCESS = config('NUCLEAR_POSTPROCESS', default='deep_watershed', cast=str)

PHASE_RESHAPE_SIZE = config('PHASE_RESHAPE_SIZE', default=512, cast=int)
CYTOPLASM_RESHAPE_SIZE = config('CYTOPLASM_RESHAPE_SIZE', default=512, cast=int)
NUCLEAR_RESHAPE_SIZE = config('NUCLEAR_RESHAPE_SIZE', default=512, cast=int)

MODEL_SIZES = {
NUCLEAR_MODEL: NUCLEAR_RESHAPE_SIZE,
PHASE_MODEL: PHASE_RESHAPE_SIZE,
CYTOPLASM_MODEL: CYTOPLASM_RESHAPE_SIZE,
}

POSTPROCESS_CHOICES = {
0: NUCLEAR_POSTPROCESS,
Expand Down
17 changes: 17 additions & 0 deletions redis_consumer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,23 @@ def pad_image(image, field):
return np.pad(image, pad_width, mode='reflect')


def unpad_image(x, pad_width):
"""Unpad image padded with the pad_width.
Args:
image (numpy.array): Image to unpad.
pad_width (list): List of pads used to pad the image with np.pad.
Returns:
numpy.array: The unpadded image.
"""
slices = []
for c in pad_width:
e = None if c[1] == 0 else -c[1]
slices.append(slice(c[0], e))
return x[tuple(slices)]


def save_numpy_array(arr, name='', subdir='', output_dir=None):
"""Split tensor into channels and save each as a tiff file.
Expand Down
21 changes: 21 additions & 0 deletions redis_consumer/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,27 @@ def test_pad_image():
np.testing.assert_equal(padded.shape, (frames, new_h, new_w, 1))


def test_unpad_image():
# 2D images
h, w = 330, 330
padded = _get_image(h, w)
pad_width = [(15, 15), (15, 15), (0, 0)]

new_h = h - (pad_width[0][0] + pad_width[0][1])
new_w = w - (pad_width[1][0] + pad_width[1][1])

unpadded = utils.unpad_image(padded, pad_width)
np.testing.assert_equal(unpadded.shape, (new_h, new_w, 1))

# 3D images
frames = np.random.randint(low=1, high=6)
imgs = np.vstack([_get_image(h, w)[None, ...] for i in range(frames)])

pad_width = [(0, 0), (15, 15), (15, 15), (0, 0)]
unpadded = utils.unpad_image(imgs, pad_width)
np.testing.assert_equal(unpadded.shape, (frames, new_h, new_w, 1))


def test_save_numpy_array():
h, w = 30, 30
c = np.random.randint(low=1, high=4)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ pytz==2019.1
keras_retinanet==0.5.1
opencv-python==4.1.0.25
deepcell-tracking==0.2.5
deepcell-toolbox==0.1.0
deepcell-toolbox==0.2.0

0 comments on commit 412eef1

Please sign in to comment.