diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index ed78ebc4..c8a2f44a 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -306,11 +306,8 @@ def lrem(key, count, value): redis_client.lrem = lrem consumer = consumers.Consumer(redis_client, DummyStorage(), queue_name) consumer.get_redis_hash = lambda: '%s:f.tiff:failed' % queue_name - print(redis_client.work_queue) - print(redis_client.processing_queue) + consumer.consume() - print(redis_client.work_queue) - print(redis_client.processing_queue) assert _processed is True _processed = False @@ -320,6 +317,17 @@ def lrem(key, count, value): consumer.consume() assert _processed is True + _processed = 0 + + def G(*_): + global _processed + _processed += 1 + return 'waiting' + + consumer._consume = G + consumer.consume() + assert _processed == N + def test__consume(self): with np.testing.assert_raises(NotImplementedError): consumer = consumers.Consumer(None, None, 'q') diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py index 8dec0643..ce9349e9 100644 --- a/redis_consumer/consumers/image_consumer.py +++ b/redis_consumer/consumers/image_consumer.py @@ -33,6 +33,8 @@ import numpy as np +from deepcell_toolbox.utils import tile_image, untile_image + from redis_consumer.consumers import TensorFlowServingConsumer from redis_consumer import utils from redis_consumer import settings @@ -165,8 +167,8 @@ def _consume(self, redis_hash): 'identity_started': self.hostname, }) - cuts = hvals.get('cuts', '0') - field = hvals.get('field_size', '61') + cuts = hvals.get('cuts', '0') # TODO: deprecated + field = hvals.get('field_size', '61') # TODO: deprecated # Overridden with LABEL_DETECT_ENABLED model_name = hvals.get('model_name') @@ -201,8 +203,13 @@ def _consume(self, redis_hash): # Save shape value for postprocessing purposes # TODO this is a big janky self._rawshape = image.shape + label = None + if settings.LABEL_DETECT_ENABLED and model_name and model_version: + self.logger.warning('Label Detection is enabled, but the model' + ' %s:%s was specified in Redis.', + model_name, model_version) - if settings.LABEL_DETECT_ENABLED: + elif settings.LABEL_DETECT_ENABLED: # Detect image label type label = hvals.get('label', '') if not label: @@ -222,16 +229,64 @@ def _consume(self, redis_hash): # Send data to the model self.update_key(redis_hash, {'status': 'predicting'}) - if streaming: - image = self.process_big_image( - cuts, image, field, model_name, model_version) + model_shape = settings.MODEL_SIZES.get( + '{}:{}'.format(model_name, model_version), max(image.shape)) + + if (image.shape[image.ndim - 3] < model_shape or + image.shape[image.ndim - 2] < model_shape): + # tiling not necessary, but image must be padded. + pad_width = [] + for i in range(image.ndim): + if i in {image.ndim - 3, image.ndim - 2}: + diff = model_shape - image.shape[i] + if diff % 2: + pad_width.append((diff // 2, diff // 2 + 1)) + else: + pad_width.append((diff // 2, diff // 2)) + else: + pad_width.append((0, 0)) + padded_img = np.pad(image, pad_width, 'reflect') + image = self.grpc_image(padded_img, model_name, model_version) + + for i, j in enumerate(image): + + self.logger.critical('output %s shape is %s', i, j.shape) + + # unpad results + pad_width.insert(0, (0, 0)) # batch size + if isinstance(image, list): + image = [utils.unpad_image(i, pad_width) for i in image] + else: + image = utils.unpad_image(image, pad_width) + + elif (image.shape[image.ndim - 3] > model_shape or + image.shape[image.ndim - 2] > model_shape): + # need to tile! + tiles, tiles_info = tile_image( + np.expand_dims(image, axis=0), + model_input_shape=(model_shape, model_shape), + stride_ratio=0.75) + + # max_batch_size is 1 by default. + # dependent on the tf-serving configuration + results = [] + for t in range(tiles.shape[0]): + output = self.grpc_image(tiles[t], model_name, model_version) + if not results: + results = output + else: + for i, o in enumerate(output): + results[i] = np.vstack((results[i], o)) + + image = [untile_image(r, tiles_info) for r in results] + else: image = self.grpc_image(image, model_name, model_version) # Post-process model results self.update_key(redis_hash, {'status': 'post-processing'}) - if settings.LABEL_DETECT_ENABLED: + if settings.LABEL_DETECT_ENABLED and label is not None: post_funcs = utils._pick_postprocess(label).split(',') else: post_funcs = hvals.get('postprocess_function', '').split(',') diff --git a/redis_consumer/consumers/image_consumer_test.py b/redis_consumer/consumers/image_consumer_test.py index 135fbd3a..bc3bb493 100644 --- a/redis_consumer/consumers/image_consumer_test.py +++ b/redis_consumer/consumers/image_consumer_test.py @@ -295,3 +295,27 @@ def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 consumer.grpc_image = grpc_image_list result = consumer._consume(dummyhash) assert result == consumer.final_status + + settings.LABEL_DETECT_ENABLED = True + settings.SCALE_DETECT_ENABLED = True + + # test with model_name and model_version + redis_client.hgetall = lambda x: { + 'model_name': 'model', + 'model_version': '0', + 'label': '0', + 'scale': '1', + 'postprocess_function': '', + 'preprocess_function': '', + 'file_name': 'test_image.tiff', + 'input_file_name': 'test_image.tiff', + 'output_file_name': 'test_image.tiff' + } + redis_client.hmset = lambda x, y: True + consumer = consumers.ImageFileConsumer(redis_client, storage, prefix) + consumer._handle_error = _handle_error + consumer.detect_scale = detect_scale + consumer.detect_label = detect_label + consumer.grpc_image = grpc_image + result = consumer._consume(dummyhash) + assert result == consumer.final_status diff --git a/redis_consumer/consumers/tracking_consumer_test.py b/redis_consumer/consumers/tracking_consumer_test.py index e43e377c..90ddcfdc 100644 --- a/redis_consumer/consumers/tracking_consumer_test.py +++ b/redis_consumer/consumers/tracking_consumer_test.py @@ -37,6 +37,7 @@ import pytest from redis_consumer import consumers +from redis_consumer import settings from redis_consumer import utils @@ -207,6 +208,23 @@ def test_is_valid_hash(self): assert consumer.is_valid_hash('track:1234567890:file.trk') is True assert consumer.is_valid_hash('track:1234567890:file.trks') is True + def test__get_tracker(self): + queue = 'track' + items = ['item%s' % x for x in range(1, 4)] + + storage = DummyStorage() + redis_client = DummyRedis(items) + redis_client.hget = lambda *x: x[0] + + shape = (5, 21, 21, 1) + raw = np.random.random(shape) + segmented = np.random.randint(1, 10, size=shape) + + settings.NORMALIZE_TRACKING = True + + consumer = consumers.TrackingConsumer(redis_client, storage, queue) + consumer._get_tracker('item1', {}, raw, segmented) + def test__consume(self): queue = 'track' items = ['item%s' % x for x in range(1, 4)] diff --git a/redis_consumer/processing.py b/redis_consumer/processing.py index d5bf1081..149ff7a1 100644 --- a/redis_consumer/processing.py +++ b/redis_consumer/processing.py @@ -23,17 +23,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Functions for pre- and post-processing image data""" +"""DEPRECATED. Please use the "deepell_toolbox" package instead. + +Functions for pre- and post-processing image data +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +# pylint: disable=W0611 + from deepcell_toolbox import normalize from deepcell_toolbox import mibi from deepcell_toolbox import watershed from deepcell_toolbox import pixelwise from deepcell_toolbox import correct_drift +from deepcell_toolbox.deep_watershed import deep_watershed + from deepcell_toolbox import retinanet_semantic_to_label_image from deepcell_toolbox import retinanet_to_label_image diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index 1c8ce0e9..7829140d 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -123,7 +123,8 @@ def _strip(x): 'mibi': processing.mibi, 'watershed': processing.watershed, 'retinanet': processing.retinanet_to_label_image, - 'retinanet-semantic': processing.retinanet_semantic_to_label_image + 'retinanet-semantic': processing.retinanet_semantic_to_label_image, + 'deep_watershed': processing.deep_watershed, }, } @@ -170,9 +171,19 @@ def _strip(x): 2: CYTOPLASM_MODEL } -PHASE_POSTPROCESS = config('PHASE_POSTPROCESS', default='retinanet-semantic', cast=str) -CYTOPLASM_POSTPROCESS = config('CYTOPLASM_POSTPROCESS', default='retinanet-semantic', cast=str) -NUCLEAR_POSTPROCESS = config('NUCLEAR_POSTPROCESS', default='retinanet', cast=str) +PHASE_POSTPROCESS = config('PHASE_POSTPROCESS', default='deep_watershed', cast=str) +CYTOPLASM_POSTPROCESS = config('CYTOPLASM_POSTPROCESS', default='deep_watershed', cast=str) +NUCLEAR_POSTPROCESS = config('NUCLEAR_POSTPROCESS', default='deep_watershed', cast=str) + +PHASE_RESHAPE_SIZE = config('PHASE_RESHAPE_SIZE', default=512, cast=int) +CYTOPLASM_RESHAPE_SIZE = config('CYTOPLASM_RESHAPE_SIZE', default=512, cast=int) +NUCLEAR_RESHAPE_SIZE = config('NUCLEAR_RESHAPE_SIZE', default=512, cast=int) + +MODEL_SIZES = { + NUCLEAR_MODEL: NUCLEAR_RESHAPE_SIZE, + PHASE_MODEL: PHASE_RESHAPE_SIZE, + CYTOPLASM_MODEL: CYTOPLASM_RESHAPE_SIZE, +} POSTPROCESS_CHOICES = { 0: NUCLEAR_POSTPROCESS, diff --git a/redis_consumer/utils.py b/redis_consumer/utils.py index ecd25a49..818b059e 100644 --- a/redis_consumer/utils.py +++ b/redis_consumer/utils.py @@ -240,6 +240,23 @@ def pad_image(image, field): return np.pad(image, pad_width, mode='reflect') +def unpad_image(x, pad_width): + """Unpad image padded with the pad_width. + + Args: + image (numpy.array): Image to unpad. + pad_width (list): List of pads used to pad the image with np.pad. + + Returns: + numpy.array: The unpadded image. + """ + slices = [] + for c in pad_width: + e = None if c[1] == 0 else -c[1] + slices.append(slice(c[0], e)) + return x[tuple(slices)] + + def save_numpy_array(arr, name='', subdir='', output_dir=None): """Split tensor into channels and save each as a tiff file. diff --git a/redis_consumer/utils_test.py b/redis_consumer/utils_test.py index a3488c81..7b4e2dee 100644 --- a/redis_consumer/utils_test.py +++ b/redis_consumer/utils_test.py @@ -182,6 +182,27 @@ def test_pad_image(): np.testing.assert_equal(padded.shape, (frames, new_h, new_w, 1)) +def test_unpad_image(): + # 2D images + h, w = 330, 330 + padded = _get_image(h, w) + pad_width = [(15, 15), (15, 15), (0, 0)] + + new_h = h - (pad_width[0][0] + pad_width[0][1]) + new_w = w - (pad_width[1][0] + pad_width[1][1]) + + unpadded = utils.unpad_image(padded, pad_width) + np.testing.assert_equal(unpadded.shape, (new_h, new_w, 1)) + + # 3D images + frames = np.random.randint(low=1, high=6) + imgs = np.vstack([_get_image(h, w)[None, ...] for i in range(frames)]) + + pad_width = [(0, 0), (15, 15), (15, 15), (0, 0)] + unpadded = utils.unpad_image(imgs, pad_width) + np.testing.assert_equal(unpadded.shape, (frames, new_h, new_w, 1)) + + def test_save_numpy_array(): h, w = 30, 30 c = np.random.randint(low=1, high=4) diff --git a/requirements.txt b/requirements.txt index 22f65acc..4fa7ea13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,4 @@ pytz==2019.1 keras_retinanet==0.5.1 opencv-python==4.1.0.25 deepcell-tracking==0.2.5 -deepcell-toolbox==0.1.0 +deepcell-toolbox==0.2.0